Revert verbose RSAG zero-copy comment; rename NRanksPerNode template param

- Restore the original two-line note about templated peer-loop unrolling
  in place of the multi-paragraph rationale block.
- Rename the kernel template parameter from NRanksPerNode to NRanks.
  The IPC domain can span multiple physical hosts under MNNVL, so the
  'PerNode' suffix is misleading; NRanks matches the runtime
  ipcDomainNranks parameter that drives template dispatch (see the
  dispatch sketch below).
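
A hypothetical sketch of that dispatch, using a toy kernel in place of
allreduceRsAgZeroCopy (the launcher name, launch configuration, and kernel
body are invented for illustration; only the switch on the runtime rank
count and the 4/8 instantiations reflect the change described above):

#include <cuda_runtime.h>

// Toy stand-in for the real kernel: templated on NRanks so the
// NPeers-iteration peer loop can be fully unrolled at compile time (and,
// per the removed comment, the per-iteration modulo reduced to an AND).
template <int NRanks>
__global__ void peerLoopDemo(int rank, int* out) {
  constexpr int NPeers = NRanks - 1;
  int acc = 0;
#pragma unroll
  for (int i = 0; i < NPeers; i++) {
    int rankIdx = (rank + i + 1) % NRanks;
    acc += rankIdx;
  }
  *out = acc;
}

// Runtime dispatch on the IPC-domain rank count (ipcDomainNranks at the
// call site); only the small fixed sizes 4 and 8 are instantiated.
inline void launchPeerLoopDemo(int ipcDomainNranks, int rank, int* devOut,
                               cudaStream_t stream) {
  switch (ipcDomainNranks) {
    case 4: peerLoopDemo<4><<<1, 1, 0, stream>>>(rank, devOut); break;
    case 8: peerLoopDemo<8><<<1, 1, 0, stream>>>(rank, devOut); break;
    default: break;  // larger domains fall back to a different algorithm
  }
}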

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Binyang Li
2026-05-06 22:16:08 +00:00
parent f0c6ac081f
commit bde23ce38e


@@ -35,17 +35,10 @@ __device__ mscclpp::DeviceSyncer globalSyncer;
//
// This approach requires registering both input and output buffers as remote
// memories (2 * nPeers handles), but avoids scratch buffer allocation and
-// the extra copy steps of the standard RSAG.
-//
-// The kernel is templated on NRanksPerNode so the compiler can keep an int4
-// register array of NPeers elements, #pragma unroll the peer loops, and turn
-// the per-iteration modulo into a single AND. This issues all NPeers remote
-// reads in parallel so their latency is overlapped instead of serialized.
-// Only small fixed sizes ({4, 8}) are instantiated; larger MNNVL domains
-// (where the int4 array would spill out of registers) must use a different
-// algorithm.
+// the extra copy steps of the standard RSAG. The NRanks template
+// parameter enables compile-time unrolling of peer loops (supports 4 or 8).
-template <int NRanksPerNode, ReduceOp OpType, typename T, typename AccumT = T>
+template <int NRanks, ReduceOp OpType, typename T, typename AccumT = T>
__global__ void __launch_bounds__(1024, 1)
allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
@@ -55,12 +48,12 @@ __global__ void __launch_bounds__(1024, 1)
assert((uintptr_t)buff % sizeof(int4) == 0);
assert((uintptr_t)resultBuff % sizeof(int4) == 0);
-constexpr int NPeers = NRanksPerNode - 1;
+constexpr int NPeers = NRanks - 1;
constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
constexpr uint32_t outputRemoteBufferOffset = NPeers;
-uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
-nelemsPerInt4 * NRanksPerNode;
-uint32_t nelemsPerRank = alignedNelems / NRanksPerNode;
+uint32_t alignedNelems =
+((nelems + NRanks - 1) / NRanks + nelemsPerInt4 - 1) / nelemsPerInt4 * nelemsPerInt4 * NRanks;
+uint32_t nelemsPerRank = alignedNelems / NRanks;
uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4;
@@ -93,7 +86,7 @@ __global__ void __launch_bounds__(1024, 1)
int4 tmp_raw = buff4[offset];
#pragma unroll
for (int i = 0; i < NPeers; i++) {
-int rankIdx = (rank + i + 1) % NRanksPerNode;
+int rankIdx = (rank + i + 1) % NRanks;
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
}
@@ -105,7 +98,7 @@ __global__ void __launch_bounds__(1024, 1)
int4 tmp = mscclpp::downcastVector<T, AccumT, int4>(acc);
#pragma unroll
for (int i = 0; i < NPeers; i++) {
-int rankIdx = (rank + i + 1) % NRanksPerNode;
+int rankIdx = (rank + i + 1) % NRanks;
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
mscclpp::write<int4>(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp);
}
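
The alignedNelems expression in the second hunk rounds each rank's share of
the buffer up to a whole number of int4 vectors, then scales back to the full
domain. A small host-side check of that arithmetic (the sample values of 1000
elements, 8 ranks, and a float payload are chosen only for illustration):

#include <cassert>
#include <cstdint>

// Mirrors the kernel's rounding: ceil-divide nelems across nRanks, round the
// per-rank share up to a multiple of nelemsPerInt4, then rescale to all ranks.
constexpr uint32_t alignedNelemsFor(uint32_t nelems, uint32_t nRanks, uint32_t nelemsPerInt4) {
  return ((nelems + nRanks - 1) / nRanks + nelemsPerInt4 - 1) / nelemsPerInt4 * nelemsPerInt4 * nRanks;
}

int main() {
  // 1000 floats across 8 ranks, 4 floats per int4:
  // ceil(1000 / 8) = 125 -> rounded up to 128 -> 128 * 8 = 1024 aligned elements.
  static_assert(alignedNelemsFor(1000, 8, 4) == 1024, "aligned total");
  assert(alignedNelemsFor(1000, 8, 4) / 8 == 128);    // nelemsPerRank
  assert(alignedNelemsFor(1000, 8, 4) / 8 / 4 == 32); // nInt4PerRank
  return 0;
}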
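
The rankIdx / peerIdx pair in the last two hunks walks the other ranks
starting just after the local rank and compresses them into a dense
0..NPeers-1 index into remoteMemories (the local rank is skipped, so ranks
below ours keep their index and ranks above shift down by one). A host-side
illustration with the example values rank = 3 and NRanks = 8:

#include <cstdio>

int main() {
  constexpr int NRanks = 8;
  constexpr int NPeers = NRanks - 1;
  const int rank = 3;  // example local rank
  for (int i = 0; i < NPeers; i++) {
    int rankIdx = (rank + i + 1) % NRanks;                 // 4 5 6 7 0 1 2
    int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;  // 3 4 5 6 0 1 2
    printf("i=%d rankIdx=%d peerIdx=%d\n", i, rankIdx, peerIdx);
  }
  return 0;
}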