mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
Restore compile-time templated NRanksPerNode for rsag_zero_copy
Recovers the per-thread int4 register array + #pragma unroll for the
{4, 8} rank cases. All NPeers remote reads are issued in parallel so
their latency overlaps instead of being serialized by the runtime
fused load+reduce loop. The runtime-domain (NVL72) fallback is
removed; the algo now returns cudaErrorInvalidValue for unsupported
ipcDomainNranks, and rsag_zero_copy is dropped from the MNNVL
candidate list in the tuning example.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -246,9 +246,6 @@ class CustomizedComm:
|
||||
if a:
|
||||
out.append(a)
|
||||
if size >= 512 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
|
||||
if a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
|
||||
if self._nvls and self.symmetric_memory and a:
|
||||
out.append(a)
|
||||
|
||||
@@ -35,26 +35,32 @@ __device__ mscclpp::DeviceSyncer globalSyncer;
|
||||
//
|
||||
// This approach requires registering both input and output buffers as remote
|
||||
// memories (2 * nPeers handles), but avoids scratch buffer allocation and
|
||||
// the extra copy steps of the standard RSAG. ipcDomainNranks is accepted at
|
||||
// runtime, which allows the same kernel to handle any NVLink-domain size
|
||||
// (including Multi-Node NVLink fabrics up to NVL72).
|
||||
// the extra copy steps of the standard RSAG.
|
||||
//
|
||||
// The kernel is templated on NRanksPerNode so the compiler can keep an int4
|
||||
// register array of NPeers elements, #pragma unroll the peer loops, and turn
|
||||
// the per-iteration modulo into a single AND. This issues all NPeers remote
|
||||
// reads in parallel so their latency is overlapped instead of serialized.
|
||||
// Only small fixed sizes ({4, 8}) are instantiated; larger MNNVL domains
|
||||
// (where the int4 array would spill out of registers) must use a different
|
||||
// algorithm.
|
||||
|
||||
template <ReduceOp OpType, typename T, typename AccumT = T>
|
||||
template <int NRanksPerNode, ReduceOp OpType, typename T, typename AccumT = T>
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle<BaseMemoryChannel>* memoryChannels,
|
||||
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank,
|
||||
int ipcDomainNranks, int worldSize, size_t nelems) {
|
||||
DeviceHandle<SwitchChannel>* switchChannels, void* remoteMemories, int rank, int worldSize,
|
||||
size_t nelems) {
|
||||
int blockId = blockIdx.x;
|
||||
|
||||
assert((uintptr_t)buff % sizeof(int4) == 0);
|
||||
assert((uintptr_t)resultBuff % sizeof(int4) == 0);
|
||||
|
||||
const int NPeers = ipcDomainNranks - 1;
|
||||
constexpr int NPeers = NRanksPerNode - 1;
|
||||
constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T);
|
||||
const uint32_t outputRemoteBufferOffset = NPeers;
|
||||
uint32_t alignedNelems = ((nelems + ipcDomainNranks - 1) / ipcDomainNranks + nelemsPerInt4 - 1) / nelemsPerInt4 *
|
||||
nelemsPerInt4 * ipcDomainNranks;
|
||||
uint32_t nelemsPerRank = alignedNelems / ipcDomainNranks;
|
||||
constexpr uint32_t outputRemoteBufferOffset = NPeers;
|
||||
uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 *
|
||||
nelemsPerInt4 * NRanksPerNode;
|
||||
uint32_t nelemsPerRank = alignedNelems / NRanksPerNode;
|
||||
uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4;
|
||||
uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4;
|
||||
|
||||
@@ -75,6 +81,7 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
memoryChannelsLocal[threadIdx.x].relaxedWait();
|
||||
}
|
||||
__syncthreads();
|
||||
int4 data[NPeers];
|
||||
// AccumInt4: when AccumT != T, use a wider accumulator type.
|
||||
// For AccumT == T, this is just int4 (no-op conversion).
|
||||
constexpr int nElemsPerInt4 = sizeof(int4) / sizeof(T);
|
||||
@@ -84,17 +91,21 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
uint32_t offset = idx + offset4 + rank * nInt4PerRank;
|
||||
if (offset >= nInt4Total) continue;
|
||||
int4 tmp_raw = buff4[offset];
|
||||
int4 data;
|
||||
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(tmp_raw);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPeers; i++) {
|
||||
int rankIdx = (rank + i + 1) % ipcDomainNranks;
|
||||
int rankIdx = (rank + i + 1) % NRanksPerNode;
|
||||
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
|
||||
data = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
|
||||
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, data);
|
||||
data[i] = mscclpp::read<int4>(((void**)remoteMemories)[peerIdx], offset);
|
||||
}
|
||||
AccumVec acc = mscclpp::upcastVector<T, AccumT, AccumVec>(tmp_raw);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPeers; i++) {
|
||||
acc = mscclpp::calVectorAccum<T, AccumT, OpType, AccumVec>(acc, data[i]);
|
||||
}
|
||||
int4 tmp = mscclpp::downcastVector<T, AccumT, int4>(acc);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NPeers; i++) {
|
||||
int rankIdx = (rank + i + 1) % ipcDomainNranks;
|
||||
int rankIdx = (rank + i + 1) % NRanksPerNode;
|
||||
int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1;
|
||||
mscclpp::write<int4>(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp);
|
||||
}
|
||||
@@ -123,9 +134,18 @@ struct AllreduceRsAgZeroCopyAdapter {
|
||||
nBlocks = 128;
|
||||
}
|
||||
}
|
||||
allreduceRsAgZeroCopy<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
|
||||
ipcDomainNranks, worldSize, nelems);
|
||||
if (ipcDomainNranks == 4) {
|
||||
allreduceRsAgZeroCopy<4, OpType, T, AccumT>
|
||||
<<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
|
||||
switchChannel, remoteMemories, rank, worldSize, nelems);
|
||||
} else if (ipcDomainNranks == 8) {
|
||||
allreduceRsAgZeroCopy<8, OpType, T, AccumT>
|
||||
<<<nBlocks, nThreadsPerBlock, 0, stream>>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels,
|
||||
switchChannel, remoteMemories, rank, worldSize, nelems);
|
||||
} else {
|
||||
WARN(ALGO, "AllreduceRsAgZeroCopy only supports ipcDomainNranks of 4 or 8, got: ", ipcDomainNranks);
|
||||
return cudaErrorInvalidValue;
|
||||
}
|
||||
return cudaGetLastError();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user