From edc9c3875172ae34e9d5202745680bb9a10bb3df Mon Sep 17 00:00:00 2001
From: Qinghua Zhou
Date: Sat, 14 Feb 2026 02:49:25 +0800
Subject: [PATCH] Support uint8 data type for Allreduce (#736)

Support uint8 data type for Allreduce.

Current limitation: uint8 is not supported for NVLS.

Performance results with rccl-tests and MSCCLPP on MI300X
(OOP = out-of-place, IP = in-place; size in bytes, count in elements,
time in us, algbw/busbw in GB/s):

| size | count | type | redop | root | OOP time | OOP algbw | OOP busbw | OOP #wrong | IP time | IP algbw | IP busbw | IP #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 512 | half | sum | -1 | 5.39 | 0.19 | 0.33 | 0 | 5.45 | 0.19 | 0.33 | 0 |
| 2048 | 1024 | half | sum | -1 | 5.53 | 0.37 | 0.65 | 0 | 5.63 | 0.36 | 0.64 | 0 |
| 4096 | 2048 | half | sum | -1 | 5.55 | 0.74 | 1.29 | 0 | 5.56 | 0.74 | 1.29 | 0 |
| 8192 | 4096 | half | sum | -1 | 5.8 | 1.41 | 2.47 | 0 | 5.84 | 1.4 | 2.46 | 0 |
| 16384 | 8192 | half | sum | -1 | 6.57 | 2.49 | 4.36 | 0 | 6.56 | 2.5 | 4.37 | 0 |
| 32768 | 16384 | half | sum | -1 | 8.02 | 4.09 | 7.15 | 0 | 8.06 | 4.07 | 7.11 | 0 |
| 65536 | 32768 | half | sum | -1 | 8.77 | 7.47 | 13.07 | 0 | 8.82 | 7.43 | 13 | 0 |
| 131072 | 65536 | half | sum | -1 | 9.61 | 13.64 | 23.87 | 0 | 9.78 | 13.4 | 23.45 | 0 |
| 262144 | 131072 | half | sum | -1 | 11.68 | 22.44 | 39.27 | 0 | 12.1 | 21.67 | 37.93 | 0 |
| 524288 | 262144 | half | sum | -1 | 13.77 | 38.08 | 66.64 | 0 | 13.87 | 37.79 | 66.13 | 0 |
| 1048576 | 524288 | half | sum | -1 | 19.11 | 54.87 | 96.03 | 0 | 19.27 | 54.42 | 95.24 | 0 |
| 2097152 | 1048576 | half | sum | -1 | 24.1 | 87 | 152.26 | 0 | 24.24 | 86.52 | 151.41 | 0 |
| 4194304 | 2097152 | half | sum | -1 | 37.16 | 112.87 | 197.52 | 0 | 37.44 | 112.03 | 196.06 | 0 |
| 8388608 | 4194304 | half | sum | -1 | 61.53 | 136.33 | 238.58 | 0 | 61.68 | 135.99 | 237.99 | 0 |
| 16777216 | 8388608 | half | sum | -1 | 108.8 | 154.22 | 269.88 | 0 | 109.2 | 153.6 | 268.79 | 0 |
| 33554432 | 16777216 | half | sum | -1 | 197.8 | 169.68 | 296.94 | 0 | 198.6 | 168.92 | 295.61 | 0 |
| 67108864 | 33554432 | half | sum | -1 | 384.6 | 174.51 | 305.39 | 0 | 385.1 | 174.27 | 304.98 | 0 |
| 134217728 | 67108864 | half | sum | -1 | 754.1 | 177.99 | 311.48 | 0 | 754.9 | 177.78 | 311.12 | 0 |
| 268435456 | 134217728 | half | sum | -1 | 1491.8 | 179.94 | 314.89 | 0 | 1493.2 | 179.77 | 314.6 | 0 |
| 536870912 | 268435456 | half | sum | -1 | 2979.6 | 180.18 | 315.31 | 0 | 2983.9 | 179.92 | 314.87 | 0 |

| size | count | type | redop | root | OOP time | OOP algbw | OOP busbw | OOP #wrong | IP time | IP algbw | IP busbw | IP #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 1024 | fp8_e4m3 | sum | -1 | 5.4 | 0.19 | 0.33 | 0 | 5.45 | 0.19 | 0.33 | 0 |
| 2048 | 2048 | fp8_e4m3 | sum | -1 | 5.5 | 0.37 | 0.65 | 0 | 5.6 | 0.37 | 0.64 | 0 |
| 4096 | 4096 | fp8_e4m3 | sum | -1 | 5.61 | 0.73 | 1.28 | 0 | 5.68 | 0.72 | 1.26 | 0 |
| 8192 | 8192 | fp8_e4m3 | sum | -1 | 5.96 | 1.38 | 2.41 | 0 | 5.98 | 1.37 | 2.4 | 0 |
| 16384 | 16384 | fp8_e4m3 | sum | -1 | 6.49 | 2.52 | 4.42 | 0 | 6.58 | 2.49 | 4.36 | 0 |
| 32768 | 32768 | fp8_e4m3 | sum | -1 | 8.09 | 4.05 | 7.09 | 0 | 8.15 | 4.02 | 7.03 | 0 |
| 65536 | 65536 | fp8_e4m3 | sum | -1 | 8.58 | 7.64 | 13.37 | 0 | 8.7 | 7.53 | 13.18 | 0 |
| 131072 | 131072 | fp8_e4m3 | sum | -1 | 9.44 | 13.88 | 24.29 | 0 | 9.62 | 13.63 | 23.85 | 0 |
| 262144 | 262144 | fp8_e4m3 | sum | -1 | 10.12 | 25.9 | 45.32 | 0 | 10.37 | 25.27 | 44.22 | 0 |
| 524288 | 524288 | fp8_e4m3 | sum | -1 | 13.73 | 38.19 | 66.82 | 0 | 13.89 | 37.74 | 66.04 | 0 |
| 1048576 | 1048576 | fp8_e4m3 | sum | -1 | 18.66 | 56.2 | 98.34 | 0 | 18.92 | 55.41 | 96.97 | 0 |
| 2097152 | 2097152 | fp8_e4m3 | sum | -1 | 24.54 | 85.46 | 149.56 | 0 | 24.63 | 85.16 | 149.03 | 0 |
| 4194304 | 4194304 | fp8_e4m3 | sum | -1 | 37.79 | 110.98 | 194.21 | 0 | 38.05 | 110.22 | 192.88 | 0 |
| 8388608 | 8388608 | fp8_e4m3 | sum | -1 | 62.22 | 134.82 | 235.94 | 0 | 62.63 | 133.94 | 234.4 | 0 |
| 16777216 | 16777216 | fp8_e4m3 | sum | -1 | 109.9 | 152.62 | 267.09 | 0 | 110.4 | 151.9 | 265.83 | 0 |
| 33554432 | 33554432 | fp8_e4m3 | sum | -1 | 201.1 | 166.82 | 291.94 | 0 | 202.3 | 165.84 | 290.22 | 0 |
| 67108864 | 67108864 | fp8_e4m3 | sum | -1 | 390 | 172.06 | 301.11 | 0 | 390.2 | 171.99 | 300.99 | 0 |
| 134217728 | 134217728 | fp8_e4m3 | sum | -1 | 763.9 | 175.7 | 307.47 | 0 | 764.2 | 175.62 | 307.34 | 0 |
| 268435456 | 268435456 | fp8_e4m3 | sum | -1 | 1509.5 | 177.83 | 311.2 | 0 | 1510.1 | 177.76 | 311.08 | 0 |
| 536870912 | 536870912 | fp8_e4m3 | sum | -1 | 3010.2 | 178.35 | 312.11 | 0 | 3014.2 | 178.11 | 311.7 | 0 |
| size | count | type | redop | root | OOP time | OOP algbw | OOP busbw | OOP #wrong | IP time | IP algbw | IP busbw | IP #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 1024 | fp8_e5m2 | sum | -1 | 5.41 | 0.19 | 0.33 | 0 | 5.44 | 0.19 | 0.33 | 0 |
| 2048 | 2048 | fp8_e5m2 | sum | -1 | 5.5 | 0.37 | 0.65 | 0 | 5.67 | 0.36 | 0.63 | 0 |
| 4096 | 4096 | fp8_e5m2 | sum | -1 | 5.61 | 0.73 | 1.28 | 0 | 5.69 | 0.72 | 1.26 | 0 |
| 8192 | 8192 | fp8_e5m2 | sum | -1 | 5.96 | 1.37 | 2.4 | 0 | 6 | 1.36 | 2.39 | 0 |
| 16384 | 16384 | fp8_e5m2 | sum | -1 | 6.63 | 2.47 | 4.32 | 0 | 6.59 | 2.49 | 4.35 | 0 |
| 32768 | 32768 | fp8_e5m2 | sum | -1 | 8.07 | 4.06 | 7.1 | 0 | 8.16 | 4.02 | 7.03 | 0 |
| 65536 | 65536 | fp8_e5m2 | sum | -1 | 8.62 | 7.61 | 13.31 | 0 | 8.73 | 7.51 | 13.14 | 0 |
| 131072 | 131072 | fp8_e5m2 | sum | -1 | 9.43 | 13.9 | 24.33 | 0 | 9.6 | 13.66 | 23.9 | 0 |
| 262144 | 262144 | fp8_e5m2 | sum | -1 | 10.11 | 25.94 | 45.39 | 0 | 10.38 | 25.26 | 44.21 | 0 |
| 524288 | 524288 | fp8_e5m2 | sum | -1 | 13.73 | 38.19 | 66.84 | 0 | 13.87 | 37.79 | 66.13 | 0 |
| 1048576 | 1048576 | fp8_e5m2 | sum | -1 | 18.65 | 56.22 | 98.39 | 0 | 18.93 | 55.38 | 96.92 | 0 |
| 2097152 | 2097152 | fp8_e5m2 | sum | -1 | 24.54 | 85.47 | 149.57 | 0 | 24.63 | 85.16 | 149.03 | 0 |
| 4194304 | 4194304 | fp8_e5m2 | sum | -1 | 37.84 | 110.83 | 193.96 | 0 | 38.01 | 110.36 | 193.12 | 0 |
| 8388608 | 8388608 | fp8_e5m2 | sum | -1 | 62.32 | 134.61 | 235.58 | 0 | 62.55 | 134.12 | 234.71 | 0 |
| 16777216 | 16777216 | fp8_e5m2 | sum | -1 | 110 | 152.58 | 267.01 | 0 | 110.3 | 152.12 | 266.21 | 0 |
| 33554432 | 33554432 | fp8_e5m2 | sum | -1 | 201.1 | 166.9 | 292.07 | 0 | 201.8 | 166.26 | 290.96 | 0 |
| 67108864 | 67108864 | fp8_e5m2 | sum | -1 | 390 | 172.07 | 301.12 | 0 | 390.5 | 171.87 | 300.78 | 0 |
| 134217728 | 134217728 | fp8_e5m2 | sum | -1 | 763.9 | 175.69 | 307.46 | 0 | 764.5 | 175.56 | 307.23 | 0 |
| 268435456 | 268435456 | fp8_e5m2 | sum | -1 | 1509.4 | 177.84 | 311.22 | 0 | 1509.8 | 177.8 | 311.14 | 0 |
| 536870912 | 536870912 | fp8_e5m2 | sum | -1 | 3013 | 178.18 | 311.82 | 0 | 3018 | 177.89 | 311.31 | 0 |

| size | count | type | redop | root | OOP time | OOP algbw | OOP busbw | OOP #wrong | IP time | IP algbw | IP busbw | IP #wrong |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 1024 | 1024 | uint8 | sum | -1 | 5.46 | 0.19 | 0.33 | 0 | 5.46 | 0.19 | 0.33 | 0 |
| 2048 | 2048 | uint8 | sum | -1 | 5.54 | 0.37 | 0.65 | 0 | 5.63 | 0.36 | 0.64 | 0 |
| 4096 | 4096 | uint8 | sum | -1 | 5.61 | 0.73 | 1.28 | 0 | 5.63 | 0.73 | 1.27 | 0 |
| 8192 | 8192 | uint8 | sum | -1 | 5.9 | 1.39 | 2.43 | 0 | 5.9 | 1.39 | 2.43 | 0 |
| 16384 | 16384 | uint8 | sum | -1 | 6.6 | 2.48 | 4.35 | 0 | 6.64 | 2.47 | 4.32 | 0 |
| 32768 | 32768 | uint8 | sum | -1 | 8.99 | 3.65 | 6.38 | 0 | 8.99 | 3.64 | 6.38 | 0 |
| 65536 | 65536 | uint8 | sum | -1 | 9.44 | 6.94 | 12.15 | 0 | 9.58 | 6.84 | 11.98 | 0 |
| 131072 | 131072 | uint8 | sum | -1 | 11.72 | 11.18 | 19.57 | 0 | 11.83 | 11.08 | 19.4 | 0 |
| 262144 | 262144 | uint8 | sum | -1 | 12.29 | 21.32 | 37.31 | 0 | 12.45 | 21.05 | 36.84 | 0 |
| 524288 | 524288 | uint8 | sum | -1 | 13.87 | 37.8 | 66.15 | 0 | 13.93 | 37.64 | 65.88 | 0 |
| 1048576 | 1048576 | uint8 | sum | -1 | 19.11 | 54.88 | 96.04 | 0 | 19.3 | 54.33 | 95.08 | 0 |
| 2097152 | 2097152 | uint8 | sum | -1 | 24.38 | 86.01 | 150.51 | 0 | 24.52 | 85.53 | 149.67 | 0 |
| 4194304 | 4194304 | uint8 | sum | -1 | 37.52 | 111.78 | 195.61 | 0 | 37.76 | 111.08 | 194.39 | 0 |
| 8388608 | 8388608 | uint8 | sum | -1 | 62.4 | 134.44 | 235.26 | 0 | 62.56 | 134.1 | 234.67 | 0 |
| 16777216 | 16777216 | uint8 | sum | -1 | 110.2 | 152.22 | 266.39 | 0 | 110.3 | 152.04 | 266.08 | 0 |
| 33554432 | 33554432 | uint8 | sum | -1 | 199.8 | 167.94 | 293.9 | 0 | 197.5 | 169.88 | 297.29 | 0 |
| 67108864 | 67108864 | uint8 | sum | -1 | 386.3 | 173.73 | 304.03 | 0 | 378.4 | 177.37 | 310.39 | 0 |
| 134217728 | 134217728 | uint8 | sum | -1 | 758 | 177.07 | 309.87 | 0 | 741.1 | 181.12 | 316.95 | 0 |
| 268435456 | 268435456 | uint8 | sum | -1 | 1500.1 | 178.95 | 313.16 | 0 | 1466.2 | 183.09 | 320.4 | 0 |
| 536870912 | 536870912 | uint8 | sum | -1 | 2991.7 | 179.45 | 314.04 | 0 | 2924.8 | 183.56 | 321.23 | 0 |
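For reference, the HIP fast path added to gpu_data_types.hpp below implements
both uint8 reductions as SWAR (SIMD-within-a-register) tricks on a packed
uint32_t: the add splits even and odd bytes into 16-bit lanes so per-byte
carries cannot leak into neighboring bytes, and the min derives a per-byte
"a < b" mask from the carry bit of a 9-bit subtraction. The host-side sketch
below is illustrative only and is not part of this patch; emulate_byte_perm
is a hypothetical stand-in for the __byte_perm device intrinsic (result byte
i is chosen by nibble i of the selector: values 0-3 pick bytes of x, 4-7 pick
bytes of y).

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical host emulation of the __byte_perm device intrinsic.
static uint32_t emulate_byte_perm(uint32_t x, uint32_t y, uint32_t sel) {
  uint64_t bytes = ((uint64_t)y << 32) | x;  // bytes 0-3 from x, 4-7 from y
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t s = (sel >> (4 * i)) & 0x7u;
    out |= (uint32_t)((bytes >> (8 * s)) & 0xffu) << (8 * i);
  }
  return out;
}

// Overflow-free per-byte add: even/odd bytes are summed in separate 16-bit
// lanes, then the low byte of each lane is gathered back into one word.
static uint32_t u8x4_add(uint32_t a, uint32_t b) {
  constexpr uint32_t even = 0x00ff00ffu;
  uint32_t x = (a & even) + (b & even);    // sums of bytes 0 and 2
  uint32_t y = (a & ~even) + (b & ~even);  // sums of bytes 1 and 3
  return emulate_byte_perm(x, y, 0x7250);
}

// Per-byte unsigned min: for each byte, a + (255 - b) + 1 = a - b + 256, so
// the carry into the next byte is 0 exactly when a < b; that flag is moved
// back into its byte and broadcast to a 0x00/0xff select mask.
static uint32_t u8x4_min(uint32_t a, uint32_t b) {
  constexpr uint32_t ones = 0x01010101u;
  constexpr uint32_t even = 0x00ff00ffu;
  uint32_t d0 = (a & even) + (~b & even) + ones;
  uint32_t d1 = ((a >> 8) & even) + (~(b >> 8) & even) + ones;
  uint32_t s = emulate_byte_perm(d0, d1, 0x7351) & ones;  // 1 where a < b
  s *= 0xffu;  // broadcast each flag bit across its byte
  return (a & s) | (b & ~s);
}

int main() {
  const uint32_t as[] = {0x00ff7f01u, 0xdeadbeefu, 0x80808080u};
  const uint32_t bs[] = {0x01020304u, 0xffffffffu, 0x7f7f7f7fu};
  for (uint32_t a : as) {
    for (uint32_t b : bs) {
      uint32_t add = u8x4_add(a, b), mn = u8x4_min(a, b);
      for (int i = 0; i < 4; ++i) {  // check against a scalar reference
        uint32_t ab = (a >> (8 * i)) & 0xffu, bb = (b >> (8 * i)) & 0xffu;
        assert(((add >> (8 * i)) & 0xffu) == ((ab + bb) & 0xffu));
        assert(((mn >> (8 * i)) & 0xffu) == (ab < bb ? ab : bb));
      }
    }
  }
  printf("SWAR uint8x4 add/min checks passed\n");
  return 0;
}
```

On CUDA the same operations map directly to the __vadd4/__vminu4 intrinsics,
which is why the patch only uses the hand-rolled bit tricks on the HIP path.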
---------

Co-authored-by: Qinghua Zhou
---
 include/mscclpp/gpu_data_types.hpp            |  39 +++++++
 src/core/executor/execution_kernel.cu         |  11 ++
 src/core/include/execution_kernel.hpp         | 100 ++++++++++--------
 src/core/include/reduce_kernel.hpp            |  14 +--
 .../collectives/allreduce/allreduce_nvls.cu   |   6 +-
 .../allreduce/allreduce_nvls_with_copy.cu     |  22 ++--
 .../allreduce/allreduce_nvls_with_copy_2.cu   |  22 ++--
 .../collectives/include/allreduce/common.hpp  |   4 +
 src/ext/nccl/algorithm_selector.cc            |   6 ++
 src/ext/nccl/datatype_conversion.hpp          |   5 +
 10 files changed, 162 insertions(+), 67 deletions(-)

diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp
index 9e7747a8..5a99355f 100644
--- a/include/mscclpp/gpu_data_types.hpp
+++ b/include/mscclpp/gpu_data_types.hpp
@@ -75,6 +75,7 @@ enum class DataType {
   BFLOAT16,  // bfloat16 precision.
   FP8_E4M3,  // FP8 with E4M3 layout.
   FP8_E5M2,  // FP8 with E5M2 layout.
+  UINT8,     // 8-bit unsigned integer.
 };
 
 /// Word array.
@@ -154,12 +155,14 @@ DEFINE_VEC(f64x1, double, 1, double);
 
 DEFINE_VEC(i32x2, int32_t, 2, int2);
 DEFINE_VEC(u32x2, uint32_t, 2, uint2);
+DEFINE_VEC(u8x2, uint8_t, 2, uint16_t);
 DEFINE_VEC(f32x2, float, 2, float2);
 DEFINE_VEC(f16x2, __half, 2, __half2);
 DEFINE_VEC(bf16x2, __bfloat16, 2, __bfloat162);
 
 DEFINE_VEC(i32x4, int32_t, 4, int4);
 DEFINE_VEC(u32x4, uint32_t, 4, uint4);
+DEFINE_VEC(u8x4, uint8_t, 4, uint32_t);
 DEFINE_VEC(f32x4, float, 4, float4);
 DEFINE_VEC(f16x4, __half, 4, uint2);
 DEFINE_VEC(bf16x4, __bfloat16, 4, uint2);
@@ -427,6 +430,20 @@ MSCCLPP_DEVICE_INLINE f8_e5m2x4 operator+(const f8_e5m2x4& a, const f8_e5m2x4& b
 }
 #endif  // defined(__FP8_TYPES_EXIST__)
 
+MSCCLPP_DEVICE_INLINE u8x4 operator+(const u8x4& a, const u8x4& b) {
+#if defined(MSCCLPP_DEVICE_HIP)
+  // Optimized uint8_t x 4 sum using byte permute to avoid overflow between adjacent bytes
+  constexpr uint32_t even = 0x00ff00ffu;
+  uint32_t ua = a.storage;
+  uint32_t ub = b.storage;
+  uint32_t x = (ua & even) + (ub & even);
+  uint32_t y = (ua & ~even) + (ub & ~even);
+  return __byte_perm(x, y, 0x7250);
+#else
+  return __vadd4(a.storage, b.storage);
+#endif
+}
+
 template <typename T>
 MSCCLPP_DEVICE_INLINE T min(const T& a, const T& b) {
   return (a < b ? a : b);
@@ -450,6 +467,28 @@ MSCCLPP_DEVICE_INLINE bf16x2 min(const bf16x2& a, const bf16x2& b) {
   return __hmin2(a, b);
 }
 
+template <>
+MSCCLPP_DEVICE_INLINE u8x4 min(const u8x4& a, const u8x4& b) {
+#if defined(MSCCLPP_DEVICE_HIP)
+  // Optimized uint8_t x 4 min using 9-bit arithmetic
+  constexpr uint32_t ones = 0x01010101u;
+  constexpr uint32_t even = 0x00ff00ffu;  // even byte mask
+  uint32_t ua = a.storage;
+  uint32_t ub = b.storage;
+  // Use 9-bit arithmetic to compute d=a-b for each byte
+  uint32_t d0 = (ua & even) + (~ub & even) + ones;
+  uint32_t d1 = ((ua >> 8) & even) + (~(ub >> 8) & even) + ones;
+  // Move sign bit of each 9-bit delta into the least bit of origin byte
+  uint32_t s = __byte_perm(d0, d1, 0x7351) & ones;
+  // Broadcast least bit across whole byte
+  s *= 0xffu;
+  // Compose result by selecting bytes via: signbit(a-b)==1 ? a : b
+  return (ua & s) | (ub & ~s);
+#else
+  return __vminu4(a.storage, b.storage);
+#endif
+}
+
 #if defined(__FP8_TYPES_EXIST__)
 template <>
 MSCCLPP_DEVICE_INLINE __fp8_e4m3 min(const __fp8_e4m3& a, const __fp8_e4m3& b) {
diff --git a/src/core/executor/execution_kernel.cu b/src/core/executor/execution_kernel.cu
index 4b1b06bc..ceddf9b7 100644
--- a/src/core/executor/execution_kernel.cu
+++ b/src/core/executor/execution_kernel.cu
@@ -32,6 +32,17 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
       NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
 #else
       );
+#endif
+      break;
+    case DataType::UINT8:
+      executionKernel<uint8_t><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
+          rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
+          localMemoryIdBegin, flag
+#if defined(ENABLE_NPKIT)
+          ,
+          NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
+#else
+      );
 #endif
       break;
     case DataType::FLOAT16:
diff --git a/src/core/include/execution_kernel.hpp b/src/core/include/execution_kernel.hpp
index 74283244..f2fad0d5 100644
--- a/src/core/include/execution_kernel.hpp
+++ b/src/core/include/execution_kernel.hpp
@@ -521,51 +521,56 @@ MSCCLPP_DEVICE_INLINE void handleCopy(const Operation& op, void* input, void* ou
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
 template <typename T>
 MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(const Operation& op, uint32_t offset, uint32_t unitSize) {
-  static_assert(sizeof(T) <= 8, "Only support type with size <= 8 bytes");
-  const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize);
-  if (size <= 0) {
+  if constexpr (std::is_same_v<T, uint8_t>) {
+    assert(false && "MULTI_LOAD_REDUCE_STORE is not supported for uint8_t data type");
     return;
-  }
-  const uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.nvlsInputBufferType, offset);
-  const uint32_t dstOffset = op.outputOffsets[0] + getOffset(op.nvlsOutputBufferType, offset);
-  assert(size % sizeof(T) == 0);
-  assert(srcOffset % sizeof(T) == 0);
-  assert(dstOffset % sizeof(T) == 0);
-
-  T* src = (T*)nvlsChannels_[op.nvlsInputIndex].mcPtr;
-  T* dst = (T*)nvlsChannels_[op.nvlsOutputIndex].mcPtr;
-  if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
-    const size_t nElem = size / sizeof(T);
-    const size_t srcOffsetElem = srcOffset / sizeof(T);
-    const size_t dstOffsetElem = dstOffset / sizeof(T);
-    VectorType<T, 4>* srcElem = reinterpret_cast<VectorType<T, 4>*>(src + srcOffsetElem);
-    VectorType<T, 4>* dstElem = reinterpret_cast<VectorType<T, 4>*>(dst + dstOffsetElem);
-    for (size_t idx = threadIdx.x; idx < nElem; idx += blockDim.x) {
-      auto val = SwitchChannelDeviceHandle::multimemLoadReduce(srcElem + idx);
-      SwitchChannelDeviceHandle::multimemStore(val, dstElem + idx);
-    }
-  } else {
-    // handle data in 16-byte unit
-    using Type16 = mscclpp::VectorType<T, 16 / sizeof(T)>;
-    const size_t nType16 = size / sizeof(Type16);
-    const size_t srcOffset16 = srcOffset / sizeof(Type16);
-    const size_t dstOffset16 = dstOffset / sizeof(Type16);
-    Type16* src16 = reinterpret_cast<Type16*>(src) + srcOffset16;
-    Type16* dst16 = reinterpret_cast<Type16*>(dst) + dstOffset16;
-    for (size_t idx = threadIdx.x; idx < nType16; idx += blockDim.x) {
-      Type16 val = SwitchChannelDeviceHandle::multimemLoadReduce(src16 + idx);
-      SwitchChannelDeviceHandle::multimemStore(val, dst16 + idx);
-    }
-    // handle rest of data
-    constexpr int RedBytes = (sizeof(T) == 8) ? 8 : 4;
-    using TypeRest = mscclpp::VectorType<T, RedBytes / sizeof(T)>;
-    const size_t processed = nType16 * sizeof(Type16);
-    const size_t nRest = (size - processed) / sizeof(TypeRest);
-    TypeRest* srcR = reinterpret_cast<TypeRest*>(src + srcOffset + processed);
-    TypeRest* dstR = reinterpret_cast<TypeRest*>(dst + dstOffset + processed);
-    for (size_t idx = threadIdx.x; idx < nRest; idx += blockDim.x) {
-      TypeRest val = SwitchChannelDeviceHandle::multimemLoadReduce(srcR + idx);
-      SwitchChannelDeviceHandle::multimemStore(val, dstR + idx);
-    }
-  }
+  } else {
+    static_assert(sizeof(T) <= 8, "Only support type with size <= 8 bytes");
+    const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize);
+    if (size <= 0) {
+      return;
+    }
+    const uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.nvlsInputBufferType, offset);
+    const uint32_t dstOffset = op.outputOffsets[0] + getOffset(op.nvlsOutputBufferType, offset);
+    assert(size % sizeof(T) == 0);
+    assert(srcOffset % sizeof(T) == 0);
+    assert(dstOffset % sizeof(T) == 0);
+
+    T* src = (T*)nvlsChannels_[op.nvlsInputIndex].mcPtr;
+    T* dst = (T*)nvlsChannels_[op.nvlsOutputIndex].mcPtr;
+    if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+      const size_t nElem = size / sizeof(T);
+      const size_t srcOffsetElem = srcOffset / sizeof(T);
+      const size_t dstOffsetElem = dstOffset / sizeof(T);
+      VectorType<T, 4>* srcElem = reinterpret_cast<VectorType<T, 4>*>(src + srcOffsetElem);
+      VectorType<T, 4>* dstElem = reinterpret_cast<VectorType<T, 4>*>(dst + dstOffsetElem);
+      for (size_t idx = threadIdx.x; idx < nElem; idx += blockDim.x) {
+        auto val = SwitchChannelDeviceHandle::multimemLoadReduce(srcElem + idx);
+        SwitchChannelDeviceHandle::multimemStore(val, dstElem + idx);
+      }
+    } else {
+      // handle data in 16-byte unit
+      using Type16 = mscclpp::VectorType<T, 16 / sizeof(T)>;
+      const size_t nType16 = size / sizeof(Type16);
+      const size_t srcOffset16 = srcOffset / sizeof(Type16);
+      const size_t dstOffset16 = dstOffset / sizeof(Type16);
+      Type16* src16 = reinterpret_cast<Type16*>(src) + srcOffset16;
+      Type16* dst16 = reinterpret_cast<Type16*>(dst) + dstOffset16;
+      for (size_t idx = threadIdx.x; idx < nType16; idx += blockDim.x) {
+        Type16 val = SwitchChannelDeviceHandle::multimemLoadReduce(src16 + idx);
+        SwitchChannelDeviceHandle::multimemStore(val, dst16 + idx);
+      }
+      // handle rest of data
+      constexpr int RedBytes = (sizeof(T) == 8) ? 8 : 4;
+      using TypeRest = mscclpp::VectorType<T, RedBytes / sizeof(T)>;
+      const size_t processed = nType16 * sizeof(Type16);
+      const size_t nRest = (size - processed) / sizeof(TypeRest);
+      TypeRest* srcR = reinterpret_cast<TypeRest*>(src + srcOffset + processed);
+      TypeRest* dstR = reinterpret_cast<TypeRest*>(dst + dstOffset + processed);
+      for (size_t idx = threadIdx.x; idx < nRest; idx += blockDim.x) {
+        TypeRest val = SwitchChannelDeviceHandle::multimemLoadReduce(srcR + idx);
+        SwitchChannelDeviceHandle::multimemStore(val, dstR + idx);
+      }
+    }
+  }
 }
@@ -894,6 +899,17 @@ class ExecutionKernel {
 #endif
         break;
 #endif  // __FP8_TYPES_EXIST__
+      case DataType::UINT8:
+        executionKernel<uint8_t><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
+            rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores,
+            localMemoryIdBegin, flag
+#if defined(ENABLE_NPKIT)
+            ,
+            NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
+#else
+        );
+#endif
+        break;
     }
   }
 #else  // !defined(MSCCLPP_DEVICE_HIP)
diff --git a/src/core/include/reduce_kernel.hpp b/src/core/include/reduce_kernel.hpp
index 00dc7714..fd9bd1e9 100644
--- a/src/core/include/reduce_kernel.hpp
+++ b/src/core/include/reduce_kernel.hpp
@@ -60,16 +60,18 @@ MSCCLPP_DEVICE_INLINE DataType cal_vector(const DataType& a, const DataType& b)
   static_assert(sizeof(DataType) >= 4, "DataType size must be at least 4 bytes");
   using CompType = typename std::conditional_t<
       std::is_same_v<DataType, f16x2>, f16x2,
-      std::conditional_t<std::is_same_v<DataType, bf16x2>, bf16x2,
+      std::conditional_t<
+          std::is_same_v<DataType, bf16x2>, bf16x2,
+          std::conditional_t<std::is_same_v<DataType, u8x4>, u8x4,
 #if defined(__FP8_TYPES_EXIST__)
-          std::conditional_t<std::is_same_v<DataType, f8_e4m3x4>, f8_e4m3x4,
-              std::conditional_t<std::is_same_v<DataType, f8_e5m2x4>, f8_e5m2x4,
+              std::conditional_t<std::is_same_v<DataType, f8_e4m3x4>, f8_e4m3x4,
+                  std::conditional_t<std::is_same_v<DataType, f8_e5m2x4>, f8_e5m2x4,
 #endif
-          T
+                      T
 #if defined(__FP8_TYPES_EXIST__)
-          >>>>;
+                      >>>>>;
 #else
-          >>;
+          >>>;
 #endif
   return cal_vector_helper(a, b);
 }
diff --git a/src/ext/collectives/allreduce/allreduce_nvls.cu b/src/ext/collectives/allreduce/allreduce_nvls.cu
index b07993a0..b73e1d27 100644
--- a/src/ext/collectives/allreduce/allreduce_nvls.cu
+++ b/src/ext/collectives/allreduce/allreduce_nvls.cu
@@ -72,8 +72,12 @@ struct NvlsAdapter {
                          mscclpp::DeviceHandle<mscclpp::SwitchChannel>* nvlsOutChannels, size_t channelInOffset,
                          size_t channelOutOffset, size_t, int rank, int nRanksPerNode, int, size_t inputSize,
                          cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
+    // uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
+    if constexpr (std::is_same_v<T, uint8_t>) {
+      return cudaErrorNotSupported;
+    } else
 #if (!defined(__CUDA_ARCH_SPECIFIC__) && !defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000)
-    if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+        if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
       return cudaErrorNotSupported;
     } else
 #endif
diff --git a/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu
index 033f3311..23d5ca4e 100644
--- a/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu
+++ b/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu
@@ -114,18 +114,22 @@ struct NvlsWithCopyAdapter {
                          DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
                          size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize,
                          cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
-#if defined(__CUDA_ARCH__)  // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
-    if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+    // uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
+    if constexpr (std::is_same_v<T, uint8_t>) {
       return cudaErrorNotSupported;
     } else
+#if defined(__CUDA_ARCH__)  // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
+        if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+      return cudaErrorNotSupported;
+    } else
 #endif
-    {
-      using ChannelType = DeviceHandle<MemoryChannel>;
-      allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
-                                                               nvlsChannels, inputSize, scratchBufferSize, rank,
-                                                               nRanksPerNode);
-      return cudaGetLastError();
-    }
+      {
+        using ChannelType = DeviceHandle<MemoryChannel>;
+        allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
+                                                                 nvlsChannels, inputSize, scratchBufferSize, rank,
+                                                                 nRanksPerNode);
+        return cudaGetLastError();
+      }
   }
 };
diff --git a/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu b/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu
index 96aa9168..1d8a3478 100644
--- a/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu
+++ b/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu
@@ -151,18 +151,22 @@ struct NvlsWithCopy2Adapter {
                          DeviceHandle<SwitchChannel>* nvlsChannels, DeviceHandle<SwitchChannel>*, size_t, size_t,
                          size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize,
                          cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) {
-#if defined(__CUDA_ARCH__)  // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
-    if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+    // uint8_t is not supported for NVLS (no hardware support for byte-level reduction)
+    if constexpr (std::is_same_v<T, uint8_t>) {
       return cudaErrorNotSupported;
     } else
+#if defined(__CUDA_ARCH__)  // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS
+        if constexpr (std::is_same_v<T, __fp8_e4m3> || std::is_same_v<T, __fp8_e5m2>) {
+      return cudaErrorNotSupported;
+    } else
 #endif
-    {
-      using ChannelType = DeviceHandle<MemoryChannel>;
-      allreduceNvlsWithCopy2<T>
-          <<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels, nvlsChannels,
-                                                     inputSize, scratchBufferSize, rank, nRanksPerNode);
-      return cudaGetLastError();
-    }
+      {
+        using ChannelType = DeviceHandle<MemoryChannel>;
+        allreduceNvlsWithCopy2<T>
+            <<<nBlocks, nThreadsPerBlock, 0, stream>>>(input, scratch, output, (ChannelType*)memoryChannels,
+                                                       nvlsChannels, inputSize, scratchBufferSize, rank, nRanksPerNode);
+        return cudaGetLastError();
+      }
   }
 };
diff --git a/src/ext/collectives/include/allreduce/common.hpp b/src/ext/collectives/include/allreduce/common.hpp
index 4c28a24a..ab82417a 100644
--- a/src/ext/collectives/include/allreduce/common.hpp
+++ b/src/ext/collectives/include/allreduce/common.hpp
@@ -96,6 +96,8 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
 #endif
   } else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
     return Adapter::call<int32_t>;
+  } else if (dtype == mscclpp::DataType::UINT8) {
+    return Adapter::call<uint8_t>;
   } else {
     return nullptr;
   }
@@ -116,6 +118,8 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) {
 #endif
   } else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) {
     return Adapter::call<int32_t>;
+  } else if (dtype == mscclpp::DataType::UINT8) {
+    return Adapter::call<uint8_t>;
   } else {
     return nullptr;
   }
diff --git a/src/ext/nccl/algorithm_selector.cc b/src/ext/nccl/algorithm_selector.cc
index 37126d5b..0e3b3cc1 100644
--- a/src/ext/nccl/algorithm_selector.cc
+++ b/src/ext/nccl/algorithm_selector.cc
@@ -13,6 +13,12 @@ namespace nccl {
 
 static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, DataType dtype) {
   bool nvlsSupported = config.nvlsSupported;
+
+  // NVLS does not support uint8_t (no hardware support for byte-level reduction)
+  if (dtype == DataType::UINT8) {
+    return false;
+  }
+
   const bool isFp8 = dtype == DataType::FP8_E4M3 || dtype == DataType::FP8_E5M2;
 
   if (!isFp8) {
diff --git a/src/ext/nccl/datatype_conversion.hpp b/src/ext/nccl/datatype_conversion.hpp
index 8dfe6aab..bb315894 100644
--- a/src/ext/nccl/datatype_conversion.hpp
+++ b/src/ext/nccl/datatype_conversion.hpp
@@ -18,6 +18,8 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
       return mscclpp::DataType::INT32;
     case ncclUint32:
       return mscclpp::DataType::UINT32;
+    case ncclUint8:
+      return mscclpp::DataType::UINT8;
     case ncclFloat16:
       return mscclpp::DataType::FLOAT16;
     case ncclFloat32:
@@ -38,6 +40,7 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) {
 // Get the size in bytes of a data type
 inline size_t getDataTypeSize(mscclpp::DataType dtype) {
   switch (dtype) {
+    case mscclpp::DataType::UINT8:
     case mscclpp::DataType::FP8_E4M3:
     case mscclpp::DataType::FP8_E5M2:
       return 1;
@@ -59,6 +62,8 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) {
       return ncclInt32;
     case mscclpp::DataType::UINT32:
       return ncclUint32;
+    case mscclpp::DataType::UINT8:
+      return ncclUint8;
     case mscclpp::DataType::FLOAT16:
       return ncclFloat16;
     case mscclpp::DataType::FLOAT32: