diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py
index 1d54cfa7..9ad7f22a 100644
--- a/examples/torch-integration/customized_comm_with_tuning.py
+++ b/examples/torch-integration/customized_comm_with_tuning.py
@@ -246,9 +246,6 @@ class CustomizedComm:
         if a:
             out.append(a)
         if size >= 512 << 10:
-            a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
-            if a:
-                out.append(a)
             a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
             if self._nvls and self.symmetric_memory and a:
                 out.append(a)
diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu
index ea664325..09fa2fe7 100644
--- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu
+++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu
@@ -35,26 +35,32 @@
 __device__ mscclpp::DeviceSyncer globalSyncer;
 
 //
 // This approach requires registering both input and output buffers as remote
 // memories (2 * nPeers handles), but avoids scratch buffer allocation and
-// the extra copy steps of the standard RSAG. ipcDomainNranks is accepted at
-// runtime, which allows the same kernel to handle any NVLink-domain size
-// (including Multi-Node NVLink fabrics up to NVL72).
+// the extra copy steps of the standard RSAG.
+//
+// The kernel is templated on NRanksPerNode so the compiler can keep an int4
+// register array of NPeers elements, #pragma unroll the peer loops, and turn
+// the per-iteration modulo into a single AND. This issues all NPeers remote
+// reads in parallel so their latency is overlapped instead of serialized.
+// Only small fixed sizes ({4, 8}) are instantiated; larger MNNVL domains
+// (where the int4 array would spill out of registers) must use a different
+// algorithm.
 
-template <Op OpType, typename T, typename AccumT>
+template <int NRanksPerNode, Op OpType, typename T, typename AccumT>
 __global__ void __launch_bounds__(1024, 1)
     allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<MemoryChannel>* memoryChannels,
-                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank,
-                          int ipcDomainNranks, int worldSize, size_t nelems) {
+                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
+                          size_t nelems) {
   int blockId = blockIdx.x;
   assert((uintptr_t)buff % sizeof(int4) == 0);
   assert((uintptr_t)resultBuff % sizeof(int4) == 0);
-  const int NPeers = ipcDomainNranks - 1;
+  constexpr int NPeers = NRanksPerNode - 1;
   constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
-  const uint32_t outputRemoteBufferOffset = NPeers;
-  uint32_t alignedNelems = ((nelems + ipcDomainNranks - 1) / ipcDomainNranks + nelemsPerInt4 - 1) / nelemsPerInt4 *
-                           nelemsPerInt4 * ipcDomainNranks;
-  uint32_t nelemsPerRank = alignedNelems / ipcDomainNranks;
+  constexpr uint32_t outputRemoteBufferOffset = NPeers;
+  uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
+                           nelemsPerInt4 * NRanksPerNode;
+  uint32_t nelemsPerRank = alignedNelems / NRanksPerNode;
   uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
   uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4;
 
@@ -75,6 +81,7 @@ __global__ void __launch_bounds__(1024, 1)
     memoryChannelsLocal[threadIdx.x].relaxedWait();
   }
   __syncthreads();
+  int4 data[NPeers];
   // AccumInt4: when AccumT != T, use a wider accumulator type.
   // For AccumT == T, this is just int4 (no-op conversion).
   constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T);
@@ -84,17 +91,21 @@ __global__ void __launch_bounds__(1024, 1)
     uint32_t offset = idx + offset4 + rank * nInt4PerRank;
     if (offset >= nInt4Total) continue;
     int4 tmp_raw = buff4[offset];
-    int4 data;
-    AccumVec acc = mscclpp::upcastVector(tmp_raw);
+#pragma unroll
     for (int i = 0; i < NPeers; i++) {
-      int rankIdx = (rank + i + 1) % ipcDomainNranks;
+      int rankIdx = (rank + i + 1) % NRanksPerNode;
       int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
-      data = mscclpp::read(((void**)remoteMemories)[peerIdx], offset);
-      acc = mscclpp::calVectorAccum(acc, data);
+      data[i] = mscclpp::read(((void**)remoteMemories)[peerIdx], offset);
+    }
+    AccumVec acc = mscclpp::upcastVector(tmp_raw);
+#pragma unroll
+    for (int i = 0; i < NPeers; i++) {
+      acc = mscclpp::calVectorAccum(acc, data[i]);
     }
     int4 tmp = mscclpp::downcastVector(acc);
+#pragma unroll
     for (int i = 0; i < NPeers; i++) {
-      int rankIdx = (rank + i + 1) % ipcDomainNranks;
+      int rankIdx = (rank + i + 1) % NRanksPerNode;
       int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
       mscclpp::write(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp);
     }
@@ -123,9 +134,18 @@ struct AllreduceRsAgZeroCopyAdapter {
         nBlocks = 128;
       }
     }
-    allreduceRsAgZeroCopy<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
-        (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
-        ipcDomainNranks, worldSize, nelems);
+    if (ipcDomainNranks == 4) {
+      allreduceRsAgZeroCopy<4, OpType, T, AccumT>
+          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
+                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
+    } else if (ipcDomainNranks == 8) {
+      allreduceRsAgZeroCopy<8, OpType, T, AccumT>
+          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
+                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
+    } else {
+      WARN(ALGO, "AllreduceRsAgZeroCopy only supports ipcDomainNranks of 4 or 8, got: ", ipcDomainNranks);
+      return cudaErrorInvalidValue;
+    }
     return cudaGetLastError();
   }
 };
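Note on the chunking arithmetic in the first .cu hunk: alignedNelems rounds nelems up so that every rank's chunk is a whole number of int4 vectors. As a worked example with T = float (nelemsPerInt4 = 4), NRanksPerNode = 8, and nelems = 1000: the per-rank share is ceil(1000 / 8) = 125 elements, which rounds up to 128 (a multiple of 4), giving alignedNelems = 1024, nelemsPerRank = 128, and nInt4PerRank = 32. Only nInt4Total = ceil(1000 / 4) = 250 int4 loads actually exist, so the tail of the last rank's chunk (offsets 250 through 255) is skipped by the "offset >= nInt4Total" guard.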
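The adapter hunk turns a runtime rank count into a compile-time template argument. Below is a minimal self-contained sketch of that pattern, not the mscclpp API: reduceFromPeers, launchReduce, and peerBufs are hypothetical names, and the thread/block sizing is illustrative. It shows the three effects the new kernel comment claims: a register-resident array sized by the template parameter, fully unrolled peer loops, and a power-of-two modulo that the compiler lowers to an AND.

#include <cuda_runtime.h>

// Kernel templated on the rank count: NPeers is a compile-time constant, so
// vals[] stays in registers and both loops fully unroll.
template <int NRanksPerNode>
__global__ void reduceFromPeers(float* const* peerBufs,  // device array of per-rank buffer pointers (e.g. IPC-mapped)
                                float* out, int rank, int n) {
  constexpr int NPeers = NRanksPerNode - 1;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;

  float vals[NPeers];  // fixed size -> registers, no local-memory spill
#pragma unroll
  for (int i = 0; i < NPeers; i++) {
    // NRanksPerNode is a power of two (4 or 8), so '%' compiles to '& (N-1)'.
    int peer = (rank + i + 1) % NRanksPerNode;
    vals[i] = peerBufs[peer][idx];  // all NPeers loads issued back to back
  }
  float acc = peerBufs[rank][idx];
#pragma unroll
  for (int i = 0; i < NPeers; i++) acc += vals[i];  // accumulate once loads land
  out[idx] = acc;
}

// Host-side dispatch mirroring the adapter: pick an instantiation from the
// runtime value and reject sizes that were not compiled.
cudaError_t launchReduce(float* const* peerBufs, float* out, int rank, int n,
                         int nRanksPerNode, cudaStream_t stream) {
  int threads = 256;
  int blocks = (n + threads - 1) / threads;
  if (nRanksPerNode == 4) {
    reduceFromPeers<4><<<blocks, threads, 0, stream>>>(peerBufs, out, rank, n);
  } else if (nRanksPerNode == 8) {
    reduceFromPeers<8><<<blocks, threads, 0, stream>>>(peerBufs, out, rank, n);
  } else {
    return cudaErrorInvalidValue;  // same policy as the adapter's WARN + error
  }
  return cudaGetLastError();
}

Splitting the load loop from the accumulate loop is what overlaps the NPeers reads; the removed fused loop forced a load-use dependency on every iteration, serializing the remote-read latency.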