From bde23ce38e6399e52d4662018935863d5654fd4a Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 6 May 2026 22:16:08 +0000 Subject: [PATCH] Revert verbose RSAG zero-copy comment; rename NRanksPerNode template param - Restore the original two-line note about the templated peer-loop unrolling instead of the multi-paragraph rationale block. - Rename the kernel template parameter from NRanksPerNode to NRanks. The IPC domain can span multiple physical hosts under MNNVL, so the 'PerNode' suffix is misleading; NRanks matches the runtime ipcDomainNranks parameter that drives template dispatch. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../allreduce/allreduce_rsag_zero_copy.cu | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index a20756ae..c678c267 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -35,17 +35,10 @@ __device__ mscclpp::DeviceSyncer globalSyncer; // // This approach requires registering both input and output buffers as remote // memories (2 * nPeers handles), but avoids scratch buffer allocation and -// the extra copy steps of the standard RSAG. -// -// The kernel is templated on NRanksPerNode so the compiler can keep an int4 -// register array of NPeers elements, #pragma unroll the peer loops, and turn -// the per-iteration modulo into a single AND. This issues all NPeers remote -// reads in parallel so their latency is overlapped instead of serialized. -// Only small fixed sizes ({4, 8}) are instantiated; larger MNNVL domains -// (where the int4 array would spill out of registers) must use a different -// algorithm. +// the extra copy steps of the standard RSAG. The NRanks template +// parameter enables compile-time unrolling of peer loops (supports 4 or 8). -template +template __global__ void __launch_bounds__(1024, 1) allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* switchChannels, void* remoteMemories, int rank, int worldSize, @@ -55,12 +48,12 @@ __global__ void __launch_bounds__(1024, 1) assert((uintptr_t)buff % sizeof(int4) == 0); assert((uintptr_t)resultBuff % sizeof(int4) == 0); - constexpr int NPeers = NRanksPerNode - 1; + constexpr int NPeers = NRanks - 1; constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T); constexpr uint32_t outputRemoteBufferOffset = NPeers; - uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 * - nelemsPerInt4 * NRanksPerNode; - uint32_t nelemsPerRank = alignedNelems / NRanksPerNode; + uint32_t alignedNelems = + ((nelems + NRanks - 1) / NRanks + nelemsPerInt4 - 1) / nelemsPerInt4 * nelemsPerInt4 * NRanks; + uint32_t nelemsPerRank = alignedNelems / NRanks; uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4; uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4; @@ -93,7 +86,7 @@ __global__ void __launch_bounds__(1024, 1) int4 tmp_raw = buff4[offset]; #pragma unroll for (int i = 0; i < NPeers; i++) { - int rankIdx = (rank + i + 1) % NRanksPerNode; + int rankIdx = (rank + i + 1) % NRanks; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; data[i] = mscclpp::read(((void**)remoteMemories)[peerIdx], offset); } @@ -105,7 +98,7 @@ __global__ void __launch_bounds__(1024, 1) int4 tmp = mscclpp::downcastVector(acc); #pragma unroll for (int i = 0; i < NPeers; i++) { - int rankIdx = (rank + i + 1) % NRanksPerNode; + int rankIdx = (rank + i + 1) % NRanks; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; mscclpp::write(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp); }