diff --git a/include/mscclpp/concurrency_device.hpp b/include/mscclpp/concurrency_device.hpp index e0ca9d94..e1c055a3 100644 --- a/include/mscclpp/concurrency_device.hpp +++ b/include/mscclpp/concurrency_device.hpp @@ -7,6 +7,8 @@ #include "atomic_device.hpp" #include "poll_device.hpp" +#define NUM_DEVICE_SYNCER_COUNTER 3 + namespace mscclpp { /// A device-wide barrier. @@ -24,15 +26,17 @@ struct DeviceSyncer { /// @param blockNum The number of blocks that will synchronize. /// @param maxSpinCount The maximum number of spin counts before asserting. Never assert if negative. MSCCLPP_DEVICE_INLINE void sync(int blockNum, int64_t maxSpinCount = 100000000) { - int targetCnt = blockNum; + unsigned int targetCnt = blockNum; __syncthreads(); if (blockNum == 1) return; if (threadIdx.x == 0) { - unsigned int tmp = preFlag_ ^ 1; - int val = (tmp << 1) - 1; - targetCnt = val == 1 ? targetCnt : 0; - atomicFetchAdd(&count_, val, memoryOrderRelease); - POLL_MAYBE_JAILBREAK((atomicLoad(&count_, memoryOrderAcquire) != targetCnt), maxSpinCount); + unsigned int tmp = (preFlag_ + 1) % NUM_DEVICE_SYNCER_COUNTER; + unsigned int next = (tmp + 1) % NUM_DEVICE_SYNCER_COUNTER; + unsigned int* count = &count_[tmp]; + count_[next] = 0; + atomicFetchAdd(count, 1U, memoryOrderRelease); + POLL_MAYBE_JAILBREAK((atomicLoad(count, memoryOrderAcquire) != targetCnt), + maxSpinCount); preFlag_ = tmp; } // We need sync here because only a single thread is checking whether @@ -43,7 +47,7 @@ struct DeviceSyncer { private: /// The counter of synchronized blocks. - int count_; + unsigned int count_[NUM_DEVICE_SYNCER_COUNTER]; /// The flag to indicate whether to increase or decrease @ref flag_. 
unsigned int preFlag_; }; diff --git a/include/mscclpp/fifo_device.hpp b/include/mscclpp/fifo_device.hpp index 03ca6339..f431b1d8 100644 --- a/include/mscclpp/fifo_device.hpp +++ b/include/mscclpp/fifo_device.hpp @@ -67,7 +67,13 @@ struct FifoDeviceHandle { // Make sure the data is visible to the host before we update the tail. #if defined(MSCCLPP_DEVICE_CUDA) +#if __CUDA_ARCH__ == 800 + // For A100, threadfence_system is more efficient than release + __threadfence_system(); + asm volatile("st.global.relaxed.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); +#else asm volatile("st.global.release.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd)); +#endif #else // !defined(MSCCLPP_DEVICE_CUDA) // store snd no later than fst. atomicStore(&(triggerPtr->snd), trigger.snd, memoryOrderRelaxed); diff --git a/include/mscclpp/memory_channel_device.hpp b/include/mscclpp/memory_channel_device.hpp index d49eb4de..1cf6a853 100644 --- a/include/mscclpp/memory_channel_device.hpp +++ b/include/mscclpp/memory_channel_device.hpp @@ -273,6 +273,14 @@ struct MemoryChannelDeviceHandle { /// Wait for the remote semaphore to send a signal. /// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative. MSCCLPP_DEVICE_INLINE void wait(int64_t maxSpinCount = 10000000) { semaphore_.wait(maxSpinCount); } + + /// Wait for the remote semaphore to send a signal. + /// + /// This function is a relaxed version of wait() and provides no guarantee on the completion of memory operations. + /// The user is responsible for applying proper fencing when using this function. + /// + /// @param maxSpinCount The maximum number of spins before asserting. Never assert if negative. 
+ MSCCLPP_DEVICE_INLINE void relaxedWait(int64_t maxSpinCount = 10000000) { semaphore_.relaxedWait(maxSpinCount); } #endif // defined(MSCCLPP_DEVICE_COMPILE) }; diff --git a/include/mscclpp/semaphore_device.hpp b/include/mscclpp/semaphore_device.hpp index 31dff14d..88ca9882 100644 --- a/include/mscclpp/semaphore_device.hpp +++ b/include/mscclpp/semaphore_device.hpp @@ -54,6 +54,17 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle { POLL_MAYBE_JAILBREAK((atomicLoad(inboundSemaphoreId, memoryOrderAcquire) < flag), maxSpinCount); } + /// Wait for the remote device to signal. + /// + /// This function is a relaxed version of wait() and provides no guarantee on the completion of memory operations. + /// The user is responsible for applying proper fencing when using this function. + /// + MSCCLPP_DEVICE_INLINE void relaxedWait(int64_t maxSpinCount = 100000000) { + (*expectedInboundSemaphoreId) += 1; + uint64_t flag = (*expectedInboundSemaphoreId); + POLL_MAYBE_JAILBREAK((atomicLoad(inboundSemaphoreId, memoryOrderRelaxed) < flag), maxSpinCount); + } + /// Signal the remote device. /// /// This function guarantees that all the memory operation before this function is completed before the remote @@ -65,7 +76,11 @@ struct MemoryDevice2DeviceSemaphoreDeviceHandle { semaphoreIncrement(); // use memoryOrderSeqCst instead of memoryOrderRelease since memoryOrderSeqCst // is more efficient on A100. +#if __CUDA_ARCH__ == 800 atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderSeqCst); +#else + atomicStore(remoteInboundSemaphoreId, semaphoreGetLocal(), memoryOrderRelease); +#endif } /// Signal the remote device. 
diff --git a/python/mscclpp_benchmark/allreduce.cu b/python/mscclpp_benchmark/allreduce.cu index dbe376a3..ce124da5 100644 --- a/python/mscclpp_benchmark/allreduce.cu +++ b/python/mscclpp_benchmark/allreduce.cu @@ -134,16 +134,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans, int4* buff4 = (int4*)buff; const int tid = threadIdx.x + blockIdx.x * blockDim.x; - // synchronize everyone - if (tid == 0) { - __threadfence_system(); - } - __syncthreads(); if (tid < nPeer) { memChans[tid].relaxedSignal(); } if (tid >= nPeer && tid < nPeer * 2) { - memChans[tid - nPeer].wait(); + memChans[tid - nPeer].relaxedWait(); } deviceSyncer.sync(gridDim.x); @@ -193,15 +188,11 @@ __device__ void allreduce1_helper(mscclpp::MemoryChannelDeviceHandle* memChans, // synchronize everyone again deviceSyncer.sync(gridDim.x); - if (tid == 0) { - __threadfence_system(); - } - __syncthreads(); if (tid < nPeer) { memChans[tid].relaxedSignal(); } if (tid >= nPeer && tid < nPeer * 2) { - memChans[tid - nPeer].wait(); + memChans[tid - nPeer].relaxedWait(); } if (READ_ONLY) { @@ -437,7 +428,7 @@ __device__ void localReduceScatterMem(mscclpp::MemoryChannelDeviceHandle* memCha memChans[peerIdx].relaxedSignal(); } for (int peerIdx = threadIdx.x + blockIdx.x * blockDim.x; peerIdx < nPeer; peerIdx += blockDim.x * nBlocks) { - memChans[peerIdx].wait(); + memChans[peerIdx].relaxedWait(); } reduceScatterDeviceSyncer.sync(nBlocks); @@ -497,7 +488,7 @@ __device__ void localAllGatherMem(mscclpp::MemoryChannelDeviceHandle* memChans, } if (threadIdx.x == 0 && peerLocalBlockIdx == 0) { memChans[peerIdx].relaxedSignal(); - memChans[peerIdx].wait(); + memChans[peerIdx].relaxedWait(); } allGatherDeviceSyncer.sync(nBlocks); size_t offset = rankChunkSize * (startRankChunkIndex + remoteRankLocalIndex) + offsetInRankChunk; diff --git a/test/deploy/perf_ndmv4.jsonl b/test/deploy/perf_ndmv4.jsonl index 75799885..ed8bc17a 100644 --- a/test/deploy/perf_ndmv4.jsonl +++ 
b/test/deploy/perf_ndmv4.jsonl @@ -3,7 +3,7 @@ {"name":"allgather", "kernel":3, "ranks":8, "ranksPerNode":8, "algBw":0.1112, "busBw":0.0973, "size":8192, "time":73.63, "target":"latency"} {"name":"allreduce", "kernel":1, "ranks":8, "ranksPerNode":8, "algBw":139.41, "busBw":243.96, "size":1073741824, "time":7701.98, "target":"throughput"} {"name":"allreduce", "kernel":2, "ranks":8, "ranksPerNode":8, "algBw":1.25, "busBw":2.19, "size":8192, "time":6.51, "target":"latency"} -{"name":"allreduce", "kernel":2, "ranks":16,"ranksPerNode":8, "algBw":0.51, "busBw":0.96, "size":8192, "time":15.96, "target":"latency"} +{"name":"allreduce", "kernel":2, "ranks":16,"ranksPerNode":8, "algBw":0.44, "busBw":0.83, "size":8192, "time":18.42, "target":"latency"} {"name":"allreduce", "kernel":3, "ranks":8, "ranksPerNode":8, "algBw":139.08, "busBw":243.40, "size":1073741824, "time":7719.85, "target":"throughput"} {"name":"allreduce", "kernel":4, "ranks":8, "ranksPerNode":8, "algBw":106.98, "busBw":187.22, "size":16777216, "time":156.81, "target":"throughput"} {"name":"allreduce", "kernel":4, "ranks":8, "ranksPerNode":8, "algBw":116.24, "busBw":203.42, "size":33554432, "time":288.65, "target":"throughput"}