mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 17:00:22 +00:00
NVLS zero-copy allreduce: support FP16 accumulator for FP8 inputs
multimem.ld_reduce on FP8 inputs accumulates in FP32 by default. The
ISA also exposes an .acc::f16 variant that keeps the reduction in
FP16, which is faster but lower precision. Plumb AccumT through:
- include/mscclpp/switch_channel_device.hpp:
Extend SwitchChannelDeviceHandle::multimemLoadReduce with an optional
AccumT template parameter. When VectorType is one of the FP8 vector
types (f8_e4m3x{4,8,16} / f8_e5m2x{4,8,16}) and AccumT is __half,
emit the .acc::f16 form of the instruction; otherwise unchanged.
- src/ext/collectives/include/allreduce/common.hpp:
Make handleMultiLoadReduceStore template on AccumT and forward it to
multimemLoadReduce<vectorType, AccumT>(...).
- src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu:
Template allreduceNvls and NvlsAdapter on AccumT and forward to
handleMultiLoadReduceStore<T, AccumT>; the existing dispatch<>
machinery already plumbs AccumT through from the algorithm context.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -76,6 +76,25 @@ class CustomizedComm:
|
||||
"default_allreduce_fullmesh": 64,
|
||||
"default_allgather_fullmesh2": 32,
|
||||
}
|
||||
# (algo_name, min_size, max_size, predicate)
|
||||
# Boundaries are inclusive on both ends. max_size=None means unbounded.
|
||||
# predicate=None means always applicable; otherwise a callable taking `self`.
|
||||
_AR_CANDIDATES_MNNVL = [
|
||||
("default_allreduce_allpair_packet", 0, 128 << 10, None),
|
||||
("default_allreduce_nvls_packet", 0, 64 << 10, lambda c: c._nvls),
|
||||
("default_allreduce_packet", 128 << 10, 4 << 20, None),
|
||||
("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory),
|
||||
("default_allreduce_rsag_zero_copy", 512 << 10, None, None),
|
||||
("default_allreduce_rsag", 512 << 10, None, None),
|
||||
]
|
||||
_AR_CANDIDATES_SINGLE = [
|
||||
("default_allreduce_packet", 0, 4 << 20, None),
|
||||
("default_allreduce_allpair_packet", 0, 4 << 20, None),
|
||||
("default_allreduce_nvls_packet", 0, 4 << 20, lambda c: c._nvls),
|
||||
("default_allreduce_rsag_zero_copy", 512 << 10, None, None),
|
||||
("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory),
|
||||
("default_allreduce_fullmesh", 0, None, lambda c: torch.version.hip is not None),
|
||||
]
|
||||
|
||||
def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False):
|
||||
self.comm = comm
|
||||
@@ -164,32 +183,12 @@ class CustomizedComm:
|
||||
return self._tune_buf
|
||||
|
||||
def _ar_candidates(self, size: int):
|
||||
out = []
|
||||
if self.multi_host_mnnvl:
|
||||
if size <= 4 << 20:
|
||||
if size <= 128 << 10:
|
||||
out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
|
||||
if size <= 64 << 10 and self._nvls:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
|
||||
if size > 128 << 10:
|
||||
out.append(self._algo("allreduce", "default_allreduce_packet"))
|
||||
if size >= 512 << 10:
|
||||
if self._nvls and self.symmetric_memory:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
|
||||
out.append(self._algo("allreduce", "default_allreduce_rsag"))
|
||||
return out
|
||||
if size <= 4 << 20:
|
||||
out.append(self._algo("allreduce", "default_allreduce_packet"))
|
||||
out.append(self._algo("allreduce", "default_allreduce_allpair_packet"))
|
||||
if self._nvls:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_packet"))
|
||||
if size >= 512 << 10:
|
||||
out.append(self._algo("allreduce", "default_allreduce_rsag_zero_copy"))
|
||||
if self._nvls and self.symmetric_memory:
|
||||
out.append(self._algo("allreduce", "default_allreduce_nvls_zero_copy"))
|
||||
if torch.version.hip is not None:
|
||||
out.append(self._algo("allreduce", "default_allreduce_fullmesh"))
|
||||
return out
|
||||
table = self._AR_CANDIDATES_MNNVL if self.multi_host_mnnvl else self._AR_CANDIDATES_SINGLE
|
||||
return [
|
||||
self._algo("allreduce", name)
|
||||
for name, lo, hi, pred in table
|
||||
if size >= lo and (hi is None or size <= hi) and (pred is None or pred(self))
|
||||
]
|
||||
|
||||
def _ag_candidates(self):
|
||||
if self.multi_host_mnnvl:
|
||||
|
||||
@@ -37,7 +37,11 @@ struct SwitchChannelDeviceHandle {
|
||||
SwitchChannelDeviceHandle::multimemStore(val, reinterpret_cast<T*>(mcPtr) + index);
|
||||
}
|
||||
|
||||
template <typename VectorType>
|
||||
/// Vectorized multimem load+reduce. The optional `AccumT` template parameter selects the
|
||||
/// accumulator: when `AccumT == __half` and `VectorType` is an FP8 vector type, the
|
||||
/// `.acc::f16` variant of the instruction is used (faster but lower precision than the
|
||||
/// default FP32 accumulator). For all other types `AccumT` is ignored.
|
||||
template <typename VectorType, typename AccumT = void>
|
||||
MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) {
|
||||
VectorType val;
|
||||
if constexpr (std::is_same_v<VectorType, i32x1>) {
|
||||
@@ -81,29 +85,71 @@ struct SwitchChannelDeviceHandle {
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else if constexpr (std::is_same_v<VectorType, f8_e4m3x4>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.e4m3x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
|
||||
if constexpr (std::is_same_v<AccumT, __half>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.e4m3x4 %0, [%1];"
|
||||
: "=r"(val.words[0])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.e4m3x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
|
||||
}
|
||||
} else if constexpr (std::is_same_v<VectorType, f8_e4m3x8>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e4m3x4 {%0,%1}, [%2];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
if constexpr (std::is_same_v<AccumT, __half>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v2.e4m3x4 {%0,%1}, [%2];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e4m3x4 {%0,%1}, [%2];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
}
|
||||
} else if constexpr (std::is_same_v<VectorType, f8_e4m3x16>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e4m3x4 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
if constexpr (std::is_same_v<AccumT, __half>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v4.e4m3x4 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e4m3x4 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
}
|
||||
} else if constexpr (std::is_same_v<VectorType, f8_e5m2x4>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.e5m2x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
|
||||
if constexpr (std::is_same_v<AccumT, __half>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.e5m2x4 %0, [%1];"
|
||||
: "=r"(val.words[0])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.e5m2x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory");
|
||||
}
|
||||
} else if constexpr (std::is_same_v<VectorType, f8_e5m2x8>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e5m2x4 {%0,%1}, [%2];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
if constexpr (std::is_same_v<AccumT, __half>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v2.e5m2x4 {%0,%1}, [%2];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e5m2x4 {%0,%1}, [%2];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
}
|
||||
} else if constexpr (std::is_same_v<VectorType, f8_e5m2x16>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e5m2x4 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
if constexpr (std::is_same_v<AccumT, __half>) {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v4.e5m2x4 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
} else {
|
||||
asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e5m2x4 {%0,%1,%2,%3}, [%4];"
|
||||
: "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3])
|
||||
: "l"(ptr)
|
||||
: "memory");
|
||||
}
|
||||
} else {
|
||||
static_assert(dependentFalse<VectorType>, "Not supported type");
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace collective {
|
||||
|
||||
constexpr int MAX_NBLOCKS = 32;
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename AccumT = T>
|
||||
__global__ void __launch_bounds__(1024, 1)
|
||||
allreduceNvls([[maybe_unused]] mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>* memoryChannels,
|
||||
[[maybe_unused]] mscclpp::DeviceHandle<mscclpp::SwitchChannel>* multicast,
|
||||
@@ -58,8 +58,8 @@ __global__ void __launch_bounds__(1024, 1)
|
||||
T* src = (T*)multicastPtr->mcPtr;
|
||||
T* dst = (T*)multicastOutPtr->mcPtr;
|
||||
if (curBlockSize > 0) {
|
||||
handleMultiLoadReduceStore(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset, curBlockSize,
|
||||
threadIdx.x, blockDim.x);
|
||||
handleMultiLoadReduceStore<T, AccumT>(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset,
|
||||
curBlockSize, threadIdx.x, blockDim.x);
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x < nPeers) {
|
||||
@@ -90,9 +90,9 @@ struct NvlsAdapter {
|
||||
#endif
|
||||
{
|
||||
using ChannelType = DeviceHandle<mscclpp::BaseMemoryChannel>;
|
||||
allreduceNvls<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>((ChannelType*)memoryChannels, nvlsChannels,
|
||||
nvlsOutChannels, channelInOffset, channelOutOffset,
|
||||
inputSize, rank, ipcDomainNranks);
|
||||
allreduceNvls<T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
|
||||
(ChannelType*)memoryChannels, nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, inputSize,
|
||||
rank, ipcDomainNranks);
|
||||
return cudaGetLastError();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ MSCCLPP_DEVICE_INLINE constexpr std::size_t calcVectorSize() {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T, typename AccumT = T>
|
||||
MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, size_t srcOffset, size_t dstOffset, size_t size,
|
||||
int tid, int nThreads) {
|
||||
// nvls can only handle 4 bytes alignment
|
||||
@@ -54,7 +54,7 @@ MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, size_t src
|
||||
vectorType* src4 = (vectorType*)src;
|
||||
vectorType* dst4 = (vectorType*)dst;
|
||||
for (size_t idx = tid; idx < nVec; idx += nThreads) {
|
||||
auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce(src4 + srcOffset4 + idx);
|
||||
auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce<vectorType, AccumT>(src4 + srcOffset4 + idx);
|
||||
mscclpp::SwitchChannelDeviceHandle::multimemStore(val, dst4 + dstOffset4 + idx);
|
||||
}
|
||||
// handle rest of data
|
||||
@@ -64,7 +64,8 @@ MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, size_t src
|
||||
const size_t startIdx = (srcOffset + processed) / sizeof(restVectorType);
|
||||
const size_t endIdx = (srcOffset + size) / sizeof(restVectorType);
|
||||
for (size_t idx = tid + startIdx; idx < endIdx; idx += nThreads) {
|
||||
auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce((restVectorType*)src + idx);
|
||||
auto val =
|
||||
mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce<restVectorType, AccumT>((restVectorType*)src + idx);
|
||||
mscclpp::SwitchChannelDeviceHandle::multimemStore(val, (restVectorType*)dst + idx);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user