multimem.ld_reduce on FP8 inputs accumulates in FP32 by default. The
ISA also exposes an .acc::f16 variant that keeps the reduction in
FP16, which is faster but lower precision. Plumb AccumT through:
- include/mscclpp/switch_channel_device.hpp:
Extend SwitchChannelDeviceHandle::multimemLoadReduce with an optional
AccumT template parameter. When VectorType is one of the FP8 vector
types (f8_e4m3x{4,8,16} / f8_e5m2x{4,8,16}) and AccumT is __half,
emit the .acc::f16 form of the instruction; otherwise the emitted
instruction is unchanged (first sketch after this list).
- src/ext/collectives/include/allreduce/common.hpp:
Make handleMultiLoadReduceStore templated on AccumT and forward it to
multimemLoadReduce<vectorType, AccumT>(...) (second sketch below).
- src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu:
Template allreduceNvls and NvlsAdapter on AccumT and forward to
handleMultiLoadReduceStore<T, AccumT>; the existing dispatch<>
machinery already plumbs AccumT through from the algorithm context
(third sketch below).
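
For reference, a minimal sketch of the switch_channel_device.hpp change,
shown as a free function and only for the 16-byte e4m3 case. The struct
layout, signature, and exact PTX spelling are illustrative assumptions,
not the real mscclpp definitions; check the PTX ISA entry for
multimem.ld_reduce before reusing the mnemonics.

```cpp
#include <cuda_fp16.h>
#include <type_traits>

// Assumed stand-in for mscclpp's f8_e4m3x16: 16 packed e4m3 values.
struct f8_e4m3x16 {
  unsigned int words[4];
};

template <typename VectorType, typename AccumT = float>
__device__ VectorType multimemLoadReduce(VectorType* multimemAddr) {
  VectorType v;
  if constexpr (std::is_same_v<VectorType, f8_e4m3x16> &&
                std::is_same_v<AccumT, __half>) {
    // FP8 input with AccumT == __half: keep the reduction in FP16 (.acc::f16).
    asm volatile(
        "multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v4.e4m3x4 "
        "{%0,%1,%2,%3}, [%4];"
        : "=r"(v.words[0]), "=r"(v.words[1]), "=r"(v.words[2]), "=r"(v.words[3])
        : "l"(multimemAddr)
        : "memory");
  } else if constexpr (std::is_same_v<VectorType, f8_e4m3x16>) {
    // Default path: no .acc qualifier, so accumulation stays in FP32.
    asm volatile(
        "multimem.ld_reduce.relaxed.sys.global.add.v4.e4m3x4 "
        "{%0,%1,%2,%3}, [%4];"
        : "=r"(v.words[0]), "=r"(v.words[1]), "=r"(v.words[2]), "=r"(v.words[3])
        : "l"(multimemAddr)
        : "memory");
  }
  // ... other VectorType specializations stay exactly as they are today ...
  return v;
}
```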
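
The common.hpp side is then just threading one extra template parameter.
A sketch building on the multimemLoadReduce sketch above; the trait, the
signature, and the plain store are placeholders for the existing helpers,
not their real names.

```cpp
#include <cstddef>
#include <cuda_fp8.h>

// Placeholder trait standing in for however common.hpp already maps an
// element type to its packed multimem vector type; only e4m3 is sketched.
template <typename T> struct VectorTypeFor;
template <> struct VectorTypeFor<__nv_fp8_e4m3> { using type = f8_e4m3x16; };

template <typename T, typename AccumT = float>
__device__ void handleMultiLoadReduceStore(T* multimemIn, T* multimemOut,
                                           size_t vecIdx) {
  using VectorType = typename VectorTypeFor<T>::type;
  auto* src = reinterpret_cast<VectorType*>(multimemIn) + vecIdx;
  auto* dst = reinterpret_cast<VectorType*>(multimemOut) + vecIdx;
  // The only change vs. the current code: forward AccumT so an FP8
  // VectorType can select the .acc::f16 load-reduce.
  VectorType v = multimemLoadReduce<VectorType, AccumT>(src);
  *dst = v;  // stand-in for the existing multimem store, which is unchanged
}
```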
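
Finally, the zero-copy NVLS kernel and its adapter only need to carry
AccumT down to that helper. The kernel body, launch configuration, and the
adapter's interface below are assumptions; in the real file the existing
dispatch<> machinery is what instantiates NvlsAdapter with the AccumT
taken from the algorithm context.

```cpp
template <typename T, typename AccumT = float>
__global__ void allreduceNvls(T* multimemIn, T* multimemOut, size_t numVectors) {
  // Grid-stride loop over packed vectors; each iteration does one multimem
  // load-reduce + store with the requested accumulation precision.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < numVectors;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    handleMultiLoadReduceStore<T, AccumT>(multimemIn, multimemOut, i);
  }
}

template <typename T, typename AccumT = float>
struct NvlsAdapter {
  // Invoked via dispatch<>, which supplies AccumT from the algorithm
  // context (float today, __half for the .acc::f16 path).
  static void launch(T* multimemIn, T* multimemOut, size_t numVectors,
                     cudaStream_t stream) {
    constexpr int kBlockSize = 256;  // illustrative launch configuration
    int numBlocks = static_cast<int>((numVectors + kBlockSize - 1) / kBlockSize);
    allreduceNvls<T, AccumT>
        <<<numBlocks, kBlockSize, 0, stream>>>(multimemIn, multimemOut, numVectors);
  }
};
```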
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>