From d1b04a3b26567f7e27c0eefb52b8e4dcc273874a Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Thu, 7 May 2026 00:38:31 +0000 Subject: [PATCH] NVLS zero-copy allreduce: support FP16 accumulator for FP8 inputs multimem.ld_reduce on FP8 inputs accumulates in FP32 by default. The ISA also exposes an .acc::f16 variant that keeps the reduction in FP16, which is faster but has lower precision. Plumb AccumT through: - include/mscclpp/switch_channel_device.hpp: Extend SwitchChannelDeviceHandle::multimemLoadReduce with an optional AccumT template parameter. When VectorType is one of the FP8 vector types (f8_e4m3x{4,8,16} / f8_e5m2x{4,8,16}) and AccumT is __half, emit the .acc::f16 form of the instruction; otherwise the emitted instruction is unchanged. - src/ext/collectives/include/allreduce/common.hpp: Template handleMultiLoadReduceStore on AccumT and forward it to multimemLoadReduce(...). - src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu: Template allreduceNvls and NvlsAdapter on AccumT and forward to handleMultiLoadReduceStore; the existing dispatch<> machinery already plumbs AccumT through from the algorithm context. - examples/torch-integration/customized_comm_with_tuning.py: Replace the nested if-chains in _ar_candidates with per-topology candidate tables (_AR_CANDIDATES_MNNVL / _AR_CANDIDATES_SINGLE) whose size bounds are inclusive on both ends. As a result, in the MNNVL path a message of exactly 128 KiB now also gets default_allreduce_packet, and default_allreduce_rsag_zero_copy is newly offered for messages of 512 KiB and above. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../customized_comm_with_tuning.py | 51 ++++++----- include/mscclpp/switch_channel_device.hpp | 84 ++++++++++++++----- .../allreduce/allreduce_nvls_zero_copy.cu | 12 +-- .../collectives/include/allreduce/common.hpp | 7 +- 4 files changed, 100 insertions(+), 54 deletions(-) diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index cbfb419d..44a5c9c1 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -76,6 +76,25 @@ class CustomizedComm: "default_allreduce_fullmesh": 64, "default_allgather_fullmesh2": 32, } + # (algo_name, min_size, max_size, predicate) + # Boundaries are inclusive on both ends. max_size=None means unbounded.
+ # predicate=None means always applicable; otherwise a callable taking `self`. + _AR_CANDIDATES_MNNVL = [ + ("default_allreduce_allpair_packet", 0, 128 << 10, None), + ("default_allreduce_nvls_packet", 0, 64 << 10, lambda c: c._nvls), + ("default_allreduce_packet", 128 << 10, 4 << 20, None), + ("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory), + ("default_allreduce_rsag_zero_copy", 512 << 10, None, None), + ("default_allreduce_rsag", 512 << 10, None, None), + ] + _AR_CANDIDATES_SINGLE = [ + ("default_allreduce_packet", 0, 4 << 20, None), + ("default_allreduce_allpair_packet", 0, 4 << 20, None), + ("default_allreduce_nvls_packet", 0, 4 << 20, lambda c: c._nvls), + ("default_allreduce_rsag_zero_copy", 512 << 10, None, None), + ("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory), + ("default_allreduce_fullmesh", 0, None, lambda c: torch.version.hip is not None), + ] def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False): self.comm = comm @@ -164,32 +183,12 @@ class CustomizedComm: return self._tune_buf def _ar_candidates(self, size: int): - out = [] - if self.multi_host_mnnvl: - if size <= 4 << 20: - if size <= 128 << 10: - out.append(self._algo("allreduce", "default_allreduce_allpair_packet")) - if size <= 64 << 10 and self._nvls: - out.append(self._algo("allreduce", "default_allreduce_nvls_packet")) - if size > 128 << 10: - out.append(self._algo("allreduce", "default_allreduce_packet")) - if size >= 512 << 10: - if self._nvls and self.symmetric_memory: - out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy")) - out.append(self._algo("allreduce", "default_allreduce_rsag")) - return out - if size <= 4 << 20: - out.append(self._algo("allreduce", "default_allreduce_packet")) - out.append(self._algo("allreduce", "default_allreduce_allpair_packet")) - if self._nvls: - out.append(self._algo("allreduce", "default_allreduce_nvls_packet")) - if size 
>= 512 << 10: - out.append(self._algo("allreduce", "default_allreduce_rsag_zero_copy")) - if self._nvls and self.symmetric_memory: - out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy")) - if torch.version.hip is not None: - out.append(self._algo("allreduce", "default_allreduce_fullmesh")) - return out + table = self._AR_CANDIDATES_MNNVL if self.multi_host_mnnvl else self._AR_CANDIDATES_SINGLE + return [ + self._algo("allreduce", name) + for name, lo, hi, pred in table + if size >= lo and (hi is None or size <= hi) and (pred is None or pred(self)) + ] def _ag_candidates(self): if self.multi_host_mnnvl: diff --git a/include/mscclpp/switch_channel_device.hpp b/include/mscclpp/switch_channel_device.hpp index b52b6572..7b749f7a 100644 --- a/include/mscclpp/switch_channel_device.hpp +++ b/include/mscclpp/switch_channel_device.hpp @@ -37,7 +37,11 @@ struct SwitchChannelDeviceHandle { SwitchChannelDeviceHandle::multimemStore(val, reinterpret_cast(mcPtr) + index); } - template + /// Vectorized multimem load+reduce. The optional `AccumT` template parameter selects the + /// accumulator: when `AccumT == __half` and `VectorType` is an FP8 vector type, the + /// `.acc::f16` variant of the instruction is used (faster but lower precision than the + /// default FP32 accumulator). For all other types `AccumT` is ignored. 
+ template MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) { VectorType val; if constexpr (std::is_same_v) { @@ -81,29 +85,71 @@ struct SwitchChannelDeviceHandle { : "l"(ptr) : "memory"); } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.e4m3x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.e4m3x4 %0, [%1];" + : "=r"(val.words[0]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.e4m3x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e4m3x4 {%0,%1}, [%2];" - : "=r"(val.words[0]), "=r"(val.words[1]) - : "l"(ptr) - : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v2.e4m3x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e4m3x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e4m3x4 {%0,%1,%2,%3}, [%4];" - : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) - : "l"(ptr) - : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v4.e4m3x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e4m3x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.e5m2x4 %0, [%1];" : "=r"(val.words[0]) : 
"l"(ptr) : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.e5m2x4 %0, [%1];" + : "=r"(val.words[0]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.e5m2x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e5m2x4 {%0,%1}, [%2];" - : "=r"(val.words[0]), "=r"(val.words[1]) - : "l"(ptr) - : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v2.e5m2x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e5m2x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e5m2x4 {%0,%1,%2,%3}, [%4];" - : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) - : "l"(ptr) - : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v4.e5m2x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e5m2x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } } else { static_assert(dependentFalse, "Not supported type"); } diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 115a229a..99146779 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -14,7 +14,7 @@ namespace collective { constexpr int MAX_NBLOCKS = 32; -template +template __global__ void 
__launch_bounds__(1024, 1) allreduceNvls([[maybe_unused]] mscclpp::DeviceHandle* memoryChannels, [[maybe_unused]] mscclpp::DeviceHandle* multicast, @@ -58,8 +58,8 @@ __global__ void __launch_bounds__(1024, 1) T* src = (T*)multicastPtr->mcPtr; T* dst = (T*)multicastOutPtr->mcPtr; if (curBlockSize > 0) { - handleMultiLoadReduceStore(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset, curBlockSize, - threadIdx.x, blockDim.x); + handleMultiLoadReduceStore(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset, + curBlockSize, threadIdx.x, blockDim.x); } __syncthreads(); if (threadIdx.x < nPeers) { @@ -90,9 +90,9 @@ struct NvlsAdapter { #endif { using ChannelType = DeviceHandle; - allreduceNvls<<>>((ChannelType*)memoryChannels, nvlsChannels, - nvlsOutChannels, channelInOffset, channelOutOffset, - inputSize, rank, ipcDomainNranks); + allreduceNvls<<>>( + (ChannelType*)memoryChannels, nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, inputSize, + rank, ipcDomainNranks); return cudaGetLastError(); } } diff --git a/src/ext/collectives/include/allreduce/common.hpp b/src/ext/collectives/include/allreduce/common.hpp index 93b18e26..22513ace 100644 --- a/src/ext/collectives/include/allreduce/common.hpp +++ b/src/ext/collectives/include/allreduce/common.hpp @@ -36,7 +36,7 @@ MSCCLPP_DEVICE_INLINE constexpr std::size_t calcVectorSize() { } } -template +template MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, size_t srcOffset, size_t dstOffset, size_t size, int tid, int nThreads) { // nvls can only handle 4 bytes alignment @@ -54,7 +54,7 @@ MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, size_t src vectorType* src4 = (vectorType*)src; vectorType* dst4 = (vectorType*)dst; for (size_t idx = tid; idx < nVec; idx += nThreads) { - auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce(src4 + srcOffset4 + idx); + auto val = 
mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce(src4 + srcOffset4 + idx); mscclpp::SwitchChannelDeviceHandle::multimemStore(val, dst4 + dstOffset4 + idx); } // handle rest of data @@ -64,7 +64,8 @@ MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, size_t src const size_t startIdx = (srcOffset + processed) / sizeof(restVectorType); const size_t endIdx = (srcOffset + size) / sizeof(restVectorType); for (size_t idx = tid + startIdx; idx < endIdx; idx += nThreads) { - auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce((restVectorType*)src + idx); + auto val = + mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce((restVectorType*)src + idx); mscclpp::SwitchChannelDeviceHandle::multimemStore(val, (restVectorType*)dst + idx); } }