Restore compile-time templated NRanksPerNode for rsag_zero_copy

Recovers the per-thread int4 register array and #pragma unroll for the
{4, 8} rank cases. All NPeers remote reads are issued in parallel so
their latency overlaps instead of being serialized through the fused
load+reduce loop of the runtime-nPeers path. The runtime-domain (NVL72)
fallback is removed: the algorithm now returns cudaErrorInvalidValue
for unsupported ipcDomainNranks values, and rsag_zero_copy is dropped
from the MNNVL candidate list in the tuning example.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Binyang Li
2026-05-01 23:09:22 +00:00
parent 2a2fca8a58
commit 2efda4d819
2 changed files with 40 additions and 23 deletions

View File

@@ -246,9 +246,6 @@ class CustomizedComm:
         if a:
             out.append(a)
         if size >= 512 << 10:
-            a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
-            if a:
-                out.append(a)
             a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
             if self._nvls and self.symmetric_memory and a:
                 out.append(a)

View File

@@ -35,26 +35,32 @@ __device__ mscclpp::DeviceSyncer globalSyncer;
 //
 // This approach requires registering both input and output buffers as remote
 // memories (2 * nPeers handles), but avoids scratch buffer allocation and
-// the extra copy steps of the standard RSAG. ipcDomainNranks is accepted at
-// runtime, which allows the same kernel to handle any NVLink-domain size
-// (including Multi-Node NVLink fabrics up to NVL72).
+// the extra copy steps of the standard RSAG.
+//
+// The kernel is templated on NRanksPerNode so the compiler can keep an int4
+// register array of NPeers elements, #pragma unroll the peer loops, and turn
+// the per-iteration modulo into a single AND. This issues all NPeers remote
+// reads in parallel so their latency is overlapped instead of serialized.
+// Only small fixed sizes ({4, 8}) are instantiated; larger MNNVL domains
+// (where the int4 array would spill out of registers) must use a different
+// algorithm.
-template <ReduceOp OpType, typename T, typename AccumT = T>
+template <int NRanksPerNode, ReduceOp OpType, typename T, typename AccumT = T>
 __global__ void __launch_bounds__(1024, 1)
     allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
-                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank,
-                          int ipcDomainNranks, int worldSize, size_t nelems) {
+                          DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
+                          size_t nelems) {
   int blockId = blockIdx.x;
   assert((uintptr_t)buff % sizeof(int4) == 0);
   assert((uintptr_t)resultBuff % sizeof(int4) == 0);
-  const int NPeers = ipcDomainNranks - 1;
+  constexpr int NPeers = NRanksPerNode - 1;
   constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
-  const uint32_t outputRemoteBufferOffset = NPeers;
-  uint32_t alignedNelems = ((nelems + ipcDomainNranks - 1) / ipcDomainNranks + nelemsPerInt4 - 1) / nelemsPerInt4 *
-                           nelemsPerInt4 * ipcDomainNranks;
-  uint32_t nelemsPerRank = alignedNelems / ipcDomainNranks;
+  constexpr uint32_t outputRemoteBufferOffset = NPeers;
+  uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
+                           nelemsPerInt4 * NRanksPerNode;
+  uint32_t nelemsPerRank = alignedNelems / NRanksPerNode;
   uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
   uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4;
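
The restored comment's "modulo into a single AND" point depends on NRanksPerNode being a compile-time power of two. A minimal standalone sketch of the strength reduction (nextPeerRank is illustrative, not part of this patch):

    // Illustrative only: with a power-of-two compile-time rank count the
    // modulo lowers to a bitwise AND; the unsigned cast lets the compiler
    // apply it without reasoning about negative operands.
    template <int NRanksPerNode>
    __device__ __forceinline__ int nextPeerRank(int rank, int i) {
      static_assert((NRanksPerNode & (NRanksPerNode - 1)) == 0,
                    "requires a power-of-two rank count");
      return (unsigned)(rank + i + 1) % NRanksPerNode;  // == (...) & (NRanksPerNode - 1)
    }

With a runtime ipcDomainNranks the same expression needs an integer division, which GPUs emulate with a multi-instruction sequence on every loop iteration.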
@@ -75,6 +81,7 @@ __global__ void __launch_bounds__(1024, 1)
     memoryChannelsLocal[threadIdx.x].relaxedWait();
   }
   __syncthreads();
+  int4 data[NPeers];
   // AccumInt4: when AccumT != T, use a wider accumulator type.
   // For AccumT == T, this is just int4 (no-op conversion).
   constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T);
@@ -84,17 +91,21 @@ __global__ void __launch_bounds__(1024, 1)
     uint32_t offset = idx + offset4 + rank * nInt4PerRank;
     if (offset >= nInt4Total) continue;
     int4 tmp_raw = buff4[offset];
-    int4 data;
-    AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(tmp_raw);
+#pragma unroll
     for (int i = 0; i < NPeers; i++) {
-      int rankIdx = (rank + i + 1) % ipcDomainNranks;
+      int rankIdx = (rank + i + 1) % NRanksPerNode;
       int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
-      data = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
-      acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, data);
+      data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
     }
+    AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(tmp_raw);
+#pragma unroll
+    for (int i = 0; i < NPeers; i++) {
+      acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, data[i]);
+    }
     int4 tmp = mscclpp::downcastVector<T, AccumT, int4>(acc);
+#pragma unroll
     for (int i = 0; i < NPeers; i++) {
-      int rankIdx = (rank + i + 1) % ipcDomainNranks;
+      int rankIdx = (rank + i + 1) % NRanksPerNode;
       int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
       mscclpp::write<int4>(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp);
     }
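
The two-loop split above is the core of the recovered optimization: with NPeers known at compile time, every peer load is issued into a register array before any of them is consumed, so the remote-read latencies overlap. A self-contained sketch of the pattern, with addInt4 standing in for mscclpp::calVectorAccum (all names hypothetical, not from this patch):

    #include <cuda_runtime.h>

    // Hypothetical stand-in for the reduction helper.
    __device__ __forceinline__ int4 addInt4(int4 a, int4 b) {
      return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
    }

    template <int NPeers>
    __device__ int4 gatherThenReduce(int4 local, int4* const* peers, uint32_t off) {
      int4 data[NPeers];  // small fixed size, so it stays in registers
    #pragma unroll
      for (int i = 0; i < NPeers; i++) data[i] = peers[i][off];  // loads issued back-to-back
      int4 acc = local;
    #pragma unroll
      for (int i = 0; i < NPeers; i++) acc = addInt4(acc, data[i]);  // reduce afterwards
      return acc;
    }

In the removed runtime-NPeers version each iteration read one value into a single register and immediately folded it into the accumulator, chaining every load behind the previous reduction.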
@@ -123,9 +134,18 @@ struct AllreduceRsAgZeroCopyAdapter {
        nBlocks = 128;
      }
    }
-    allreduceRsAgZeroCopy<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
-        (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
-        ipcDomainNranks, worldSize, nelems);
+    if (ipcDomainNranks == 4) {
+      allreduceRsAgZeroCopy<4, OpType, T, AccumT>
+          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
+                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
+    } else if (ipcDomainNranks == 8) {
+      allreduceRsAgZeroCopy<8, OpType, T, AccumT>
+          <<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
+                                                     switchChannel, remoteMemories, rank, worldSize, nelems);
+    } else {
+      WARN(ALGO, "AllreduceRsAgZeroCopy only supports ipcDomainNranks of 4 or 8, got: ", ipcDomainNranks);
+      return cudaErrorInvalidValue;
+    }
     return cudaGetLastError();
   }
 };
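
With the runtime fallback gone, callers that probe algorithm candidates should treat cudaErrorInvalidValue as "unsupported configuration, try the next candidate" rather than a hard failure. A hypothetical caller-side sketch (launchWithFallback is illustrative, not from this patch):

    #include <cuda_runtime.h>
    #include <functional>
    #include <vector>

    // Try each candidate launcher in order; cudaErrorInvalidValue signals an
    // unsupported configuration, so fall through to the next algorithm.
    cudaError_t launchWithFallback(const std::vector<std::function<cudaError_t()>>& candidates) {
      for (const auto& tryLaunch : candidates) {
        cudaError_t err = tryLaunch();
        if (err != cudaErrorInvalidValue) return err;  // success or a real error
      }
      return cudaErrorInvalidValue;  // no candidate supports this configuration
    }

This matches the tuning-example change in the first file: on NVLink domains whose size is neither 4 nor 8, rsag_zero_copy is never offered as a candidate, so the invalid-value return is a backstop rather than the normal selection path.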