From 06f31994dc1360c865efbd024d506922df6be1c4 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 22 Apr 2025 14:03:37 -0700 Subject: [PATCH] Fix performance issue introduced in PR: 499 (#505) 1. use `fence+relaxed` to replace `release` for fifo. `fence+relax` is more efficient on A100 2. Update the deviceSyncer. Previous one cannot handle threadBlock number change correctly. Use three counters to solve this issue. Reset previous counter before sync on current counter. 3. Introduce relaxedWait which can be used with relaxedSignal for case doesn't need guarantee the memory visibility --- include/mscclpp/concurrency_device.hpp | 18 +++++++++++------- include/mscclpp/fifo_device.hpp | 6 ++++++ include/mscclpp/memory_channel_device.hpp | 8 ++++++++ include/mscclpp/semaphore_device.hpp | 15 +++++++++++++++ python/mscclpp_benchmark/allreduce.cu | 17 ++++------------- test/deploy/perf_ndmv4.jsonl | 2 +- 6 files changed, 45 insertions(+), 21 deletions(-) diff --git a/include/mscclpp/concurrency_device.hpp b/include/mscclpp/concurrency_device.hpp index e0ca9d94..e1c055a3 100644 --- a/include/mscclpp/concurrency_device.hpp +++ b/include/mscclpp/concurrency_device.hpp @@ -7,6 +7,8 @@ #include "atomic_device.hpp" #include "poll_device.hpp" +#define NUM_DEVICE_SYNCER_COUNTER 3 + namespace mscclpp { /// A device-wide barrier. @@ -24,15 +26,17 @@ struct DeviceSyncer { /// @param blockNum The number of blocks that will synchronize. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. MSCCLPP_DEVICE_INLINE void sync(int blockNum, int64_t maxSpinCount = 100000000) { - int targetCnt = blockNum; + unsigned int targetCnt = blockNum; __syncthreads(); if (blockNum == 1) return; if (threadIdx.x == 0) { - unsigned int tmp = preFlag_ ^ 1; - int val = (tmp << 1) - 1; - targetCnt = val == 1 ? targetCnt : 0; - atomicFetchAdd(&count_, val, memoryOrderRelease); - POLL_MAYBE_JAILBREAK((atomicLoad(&count_, memoryOrderAcquire) != targetCnt), maxSpinCount); + unsigned int tmp = (preFlag_ + 1) % NUM_DEVICE_SYNCER_COUNTER; + unsigned int next = (tmp + 1) % NUM_DEVICE_SYNCER_COUNTER; + unsigned int* count = &count_[tmp]; + count_[next] = 0; + atomicFetchAdd(count, 1U, memoryOrderRelease); + POLL_MAYBE_JAILBREAK((atomicLoad(count, memoryOrderAcquire) != targetCnt), + maxSpinCount); preFlag_ = tmp; } // We need sync here because only a single thread is checking whether @@ -43,7 +47,7 @@ struct DeviceSyncer { private: /// The counter of synchronized blocks. - int count_; + unsigned int count_[NUM_DEVICE_SYNCER_COUNTER]; /// The flag to indicate whether to increase or decrease @ref flag_. unsigned int preFlag_; }; diff --git a/include/mscclpp/fifo_device.hpp b/include/mscclpp/fifo_device.hpp index 03ca6339..f431b1d8 100644 --- a/include/mscclpp/fifo_device.hpp +++ b/include/mscclpp/fifo_device.hpp @@ -67,7 +67,13 @@ struct FifoDeviceHandle { // Make sure the data is visible to the host before we update the tail. #if defined(MSCCLPP_DEVICE_CUDA) +#if __CUDA_ARCH__ == 800 + // For A100, threadfence_system is more efficient than release + __threadfence_system(); + asm volatile("st.global.relaxed.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); +#else asm volatile("st.global.release.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); +#endif #else // !defined(MSCCLPP_DEVICE_CUDA) // store snd no later than fst. atomicStore(&(triggerPtr->snd), trigger.snd, memoryOrderRelaxed); diff --git a/include/mscclpp/memory_channel_device.hpp b/include/mscclpp/memory_channel_device.hpp index d49eb4de..1cf6a853 100644 --- a/include/mscclpp/memory_channel_device.hpp +++ b/include/mscclpp/memory_channel_device.hpp @@ -273,6 +273,14 @@ struct MemoryChannelDeviceHandle { /// Wait for the remote semaphore to send a signal. /// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative. MSCCLPP_DEVICE_INLINE void wait(int64_t maxSpinCount = 10000000) { semaphore_.wait(maxSpinCount); } + + /// Wait for the remote semaphore to send a signal. + /// + /// This function is a relaxed version of signal() and provides no guarantee on the completion of memory operations. + /// User requires to call proper fencing before using this function. + /// + /// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative. + MSCCLPP_DEVICE_INLINE void relaxedWait() { semaphore_.relaxedWait(); } #endif // defined(MSCCLPP_DEVICE_COMPILE) }; diff --git a/include/mscclpp/semaphore_device.hpp b/include/mscclpp/semaphore_device.hpp index 31dff14d..88ca9882 100644 --- a/include/mscclpp/semaphore_device.hpp +++ b/include/mscclpp/semaphore_device.hpp @@ -54,6 +54,17 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle { POLL_MAYBE_JAILBREAK((atomicLoad(inboundSemaphoreId, memoryOrderAcquire) < flag), maxSpinCount); } + /// Wait for the remote device to signal. + /// + /// This function is a relaxed version of Wait() and provides no guarantee on the completion of memory operations. + /// User requires to call proper fencing before using this function. + /// + MSCCLPP_DEVICE_INLINE void relaxedWait(int64_t maxSpinCount = 100000000) { + (*expectedInboundSemaphoreId) += 1; + uint64_t flag = (*expectedInboundSemaphoreId); + POLL_MAYBE_JAILBREAK((atomicLoad(inboundSemaphoreId, memoryOrderRelaxed) < flag), maxSpinCount); + } + /// Signal the remote device. /// /// This function guarantees that all the memory operation before this function is completed before the remote @@ -65,7 +76,11 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle { semaphoreIncrement(); // use memoryOrderSeqCst instead of memoryOrderRelease since memoryOrderSeqCst // is more efficient on A100. +#if __CUDA_ARCH__ == 800 atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderSeqCst); +#else + atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderRelease); +#endif } /// Signal the remote device. diff --git a/python/mscclpp_benchmark/allreduce.cu b/python/mscclpp_benchmark/allreduce.cu index dbe376a3..ce124da5 100644 --- a/python/mscclpp_benchmark/allreduce.cu +++ b/python/mscclpp_benchmark/allreduce.cu @@ -134,16 +134,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans, int4* buff4 = (int4*)buff; const int tid = threadIdx.x + blockIdx.x * blockDim.x; - // synchronize everyone - if (tid == 0) { - __threadfence_system(); - } - __syncthreads(); if (tid < nPeer) { memChans[tid].relaxedSignal(); } if (tid >= nPeer && tid < nPeer * 2) { - memChans[tid - nPeer].wait(); + memChans[tid - nPeer].relaxedWait(); } deviceSyncer.sync(gridDim.x); @@ -193,15 +188,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans, // synchronize everyone again deviceSyncer.sync(gridDim.x); - if (tid == 0) { - __threadfence_system(); - } - __syncthreads(); if (tid < nPeer) { memChans[tid].relaxedSignal(); } if (tid >= nPeer && tid < nPeer * 2) { - memChans[tid - nPeer].wait(); + memChans[tid - nPeer].relaxedWait(); } if (READ_ONLY) { @@ -437,7 +428,7 @@ __device__ void localReduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memCha memChans[peerIdx].relaxedSignal(); } for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) { - memChans[peerIdx].wait(); + memChans[peerIdx].relaxedWait(); } reduceScatterDeviceSyncer.sync(nBlocks); @@ -497,7 +488,7 @@ __device__ void localAllGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans, } if (threadIdx.x == 0 && peerLocalBlockIdx == 0) { memChans[peerIdx].relaxedSignal(); - memChans[peerIdx].wait(); + memChans[peerIdx].relaxedWait(); } allGatherDeviceSyncer.sync(nBlocks); size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk; diff --git a/test/deploy/perf_ndmv4.jsonl b/test/deploy/perf_ndmv4.jsonl index 75799885..ed8bc17a 100644 --- a/test/deploy/perf_ndmv4.jsonl +++ b/test/deploy/perf_ndmv4.jsonl @@ -3,7 +3,7 @@ {"name":"allgather", "kernel":3, "ranks":8, "ranksPerNode":8, "algBw":0.1112, "busBw":0.0973, "size":8192, "time":73.63, "target":"latency"} {"name":"allreduce", "kernel":1, "ranks":8, "ranksPerNode":8, "algBw":139.41, "busBw":243.96, "size":1073741824, "time":7701.98, "target":"throughput"} {"name":"allreduce", "kernel":2, "ranks":8, "ranksPerNode":8, "algBw":1.25, "busBw":2.19, "size":8192, "time":6.51, "target":"latency"} -{"name":"allreduce", "kernel":2, "ranks":16,"ranksPerNode":8, "algBw":0.51, "busBw":0.96, "size":8192, "time":15.96, "target":"latency"} +{"name":"allreduce", "kernel":2, "ranks":16,"ranksPerNode":8, "algBw":0.44, "busBw":0.83, "size":8192, "time":18.42, "target":"latency"} {"name":"allreduce", "kernel":3, "ranks":8, "ranksPerNode":8, "algBw":139.08, "busBw":243.40, "size":1073741824, "time":7719.85, "target":"throughput"} {"name":"allreduce", "kernel":4, "ranks":8, "ranksPerNode":8, "algBw":106.98, "busBw":187.22, "size":16777216, "time":156.81, "target":"throughput"} {"name":"allreduce", "kernel":4, "ranks":8, "ranksPerNode":8, "algBw":116.24, "busBw":203.42, "size":33554432, "time":288.65, "target":"throughput"}