diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp index 9e7747a8..5a99355f 100644 --- a/include/mscclpp/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -75,6 +75,7 @@ enum class DataType { BFLOAT16, // bfloat16 precision. FP8_E4M3, // FP8 with E4M3 layout. FP8_E5M2, // FP8 with E5M2 layout. + UINT8, // 8-bit unsigned integer. }; /// Word array. @@ -154,12 +155,14 @@ DEFINE_VEC(f64x1, double, 1, double); DEFINE_VEC(i32x2, int32_t, 2, int2); DEFINE_VEC(u32x2, uint32_t, 2, uint2); +DEFINE_VEC(u8x2, uint8_t, 2, uint16_t); DEFINE_VEC(f32x2, float, 2, float2); DEFINE_VEC(f16x2, __half, 2, __half2); DEFINE_VEC(bf16x2, __bfloat16, 2, __bfloat162); DEFINE_VEC(i32x4, int32_t, 4, int4); DEFINE_VEC(u32x4, uint32_t, 4, uint4); +DEFINE_VEC(u8x4, uint8_t, 4, uint32_t); DEFINE_VEC(f32x4, float, 4, float4); DEFINE_VEC(f16x4, __half, 4, uint2); DEFINE_VEC(bf16x4, __bfloat16, 4, uint2); @@ -427,6 +430,20 @@ MSCCLPP_DEVICE_INLINE f8_e5m2x4 operator+(const f8_e5m2x4& a, const f8_e5m2x4& b } #endif // defined(__FP8_TYPES_EXIST__) +MSCCLPP_DEVICE_INLINE u8x4 operator+(const u8x4& a, const u8x4& b) { +#if defined(MSCCLPP_DEVICE_HIP) + // Optimized uint8_t x 4 sum using byte permute to avoid overflow between adjacent bytes + constexpr uint32_t even = 0x00ff00ffu; + uint32_t ua = a.storage; + uint32_t ub = b.storage; + uint32_t x = (ua & even) + (ub & even); + uint32_t y = (ua & ~even) + (ub & ~even); + return __byte_perm(x, y, 0x7250); +#else + return __vadd4(a.storage, b.storage); +#endif +} + template MSCCLPP_DEVICE_INLINE T min(const T& a, const T& b) { return (a < b ? a : b); @@ -450,6 +467,28 @@ MSCCLPP_DEVICE_INLINE bf16x2 min(const bf16x2& a, const bf16x2& b) { return __hmin2(a, b); } +template <> +MSCCLPP_DEVICE_INLINE u8x4 min(const u8x4& a, const u8x4& b) { +#if defined(MSCCLPP_DEVICE_HIP) + // Optimized uint8_t x 4 min using 9-bit arithmetic + constexpr uint32_t ones = 0x01010101u; + constexpr uint32_t even = 0x00ff00ffu; // even byte mask + uint32_t ua = a.storage; + uint32_t ub = b.storage; + // Use 9-bit arithmetic to compute d=a-b for each byte + uint32_t d0 = (ua & even) + (~ub & even) + ones; + uint32_t d1 = ((ua >> 8) & even) + (~(ub >> 8) & even) + ones; + // Move sign bit of each 9-bit delta into the least bit of origin byte + uint32_t s = __byte_perm(d0, d1, 0x7351) & ones; + // Broadcast least bit across whole byte + s *= 0xffu; + // Compose result by selecting bytes via: signbit(a-b)==1 ? a : b + return (ua & s) | (ub & ~s); +#else + return __vminu4(a.storage, b.storage); +#endif +} + #if defined(__FP8_TYPES_EXIST__) template <> MSCCLPP_DEVICE_INLINE __fp8_e4m3 min(const __fp8_e4m3& a, const __fp8_e4m3& b) { diff --git a/src/core/executor/execution_kernel.cu b/src/core/executor/execution_kernel.cu index 4b1b06bc..ceddf9b7 100644 --- a/src/core/executor/execution_kernel.cu +++ b/src/core/executor/execution_kernel.cu @@ -32,6 +32,17 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); #else ); +#endif + break; + case DataType::UINT8: + executionKernel<<>>( + rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, + localMemoryIdBegin, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); #endif break; case DataType::FLOAT16: diff --git a/src/core/include/execution_kernel.hpp b/src/core/include/execution_kernel.hpp index 74283244..f2fad0d5 100644 --- a/src/core/include/execution_kernel.hpp +++ b/src/core/include/execution_kernel.hpp @@ -521,51 +521,56 @@ MSCCLPP_DEVICE_INLINE void handleCopy(const Operation& op, void* input, void* ou #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 template MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(const Operation& op, uint32_t offset, uint32_t unitSize) { - static_assert(sizeof(T) <= 8, "Only support type with size <= 8 bytes"); - const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); - if (size <= 0) { + if constexpr (std::is_same_v) { + assert(false && "MULTI_LOAD_REDUCE_STORE is not supported for uint8_t data type"); return; - } - const uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.nvlsInputBufferType, offset); - const uint32_t dstOffset = op.outputOffsets[0] + getOffset(op.nvlsOutputBufferType, offset); - assert(size % sizeof(T) == 0); - assert(srcOffset % sizeof(T) == 0); - assert(dstOffset % sizeof(T) == 0); - - T* src = (T*)nvlsChannels_[op.nvlsInputIndex].mcPtr; - T* dst = (T*)nvlsChannels_[op.nvlsOutputIndex].mcPtr; - if constexpr (std::is_same_v || std::is_same_v) { - const size_t nElem = size / sizeof(T); - const size_t srcOffsetElem = srcOffset / sizeof(T); - const size_t dstOffsetElem = dstOffset / sizeof(T); - VectorType* srcElem = reinterpret_cast*>(src + srcOffsetElem); - VectorType* dstElem = reinterpret_cast*>(dst + dstOffsetElem); - for (size_t idx = threadIdx.x; idx < nElem; idx += blockDim.x) { - auto val = SwitchChannelDeviceHandle::multimemLoadReduce(srcElem + idx); - SwitchChannelDeviceHandle::multimemStore(val, dstElem + idx); - } } else { - // handle data in 16-byte unit - using Type16 = mscclpp::VectorType; - const size_t nType16 = size / sizeof(Type16); - const size_t srcOffset16 = srcOffset / sizeof(Type16); - const size_t dstOffset16 = dstOffset / sizeof(Type16); - Type16* src16 = reinterpret_cast(src) + srcOffset16; - Type16* dst16 = reinterpret_cast(dst) + dstOffset16; - for (size_t idx = threadIdx.x; idx < nType16; idx += blockDim.x) { - Type16 val = SwitchChannelDeviceHandle::multimemLoadReduce(src16 + idx); - SwitchChannelDeviceHandle::multimemStore(val, dst16 + idx); + static_assert(sizeof(T) <= 8, "Only support type with size <= 8 bytes"); + const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); + if (size <= 0) { + return; } - // handle rest of data - constexpr int RedBytes = (sizeof(T) == 8) ? 8 : 4; - using TypeRest = mscclpp::VectorType; - const size_t processed = nType16 * sizeof(Type16); - const size_t nRest = (size - processed) / sizeof(TypeRest); - TypeRest* srcR = reinterpret_cast(src + srcOffset + processed); - TypeRest* dstR = reinterpret_cast(dst + dstOffset + processed); - for (size_t idx = threadIdx.x; idx < nRest; idx += blockDim.x) { - TypeRest val = SwitchChannelDeviceHandle::multimemLoadReduce(srcR + idx); - SwitchChannelDeviceHandle::multimemStore(val, dstR + idx); + const uint32_t srcOffset = op.inputOffsets[0] + getOffset(op.nvlsInputBufferType, offset); + const uint32_t dstOffset = op.outputOffsets[0] + getOffset(op.nvlsOutputBufferType, offset); + assert(size % sizeof(T) == 0); + assert(srcOffset % sizeof(T) == 0); + assert(dstOffset % sizeof(T) == 0); + + T* src = (T*)nvlsChannels_[op.nvlsInputIndex].mcPtr; + T* dst = (T*)nvlsChannels_[op.nvlsOutputIndex].mcPtr; + if constexpr (std::is_same_v || std::is_same_v) { + const size_t nElem = size / sizeof(T); + const size_t srcOffsetElem = srcOffset / sizeof(T); + const size_t dstOffsetElem = dstOffset / sizeof(T); + VectorType* srcElem = reinterpret_cast*>(src + srcOffsetElem); + VectorType* dstElem = reinterpret_cast*>(dst + dstOffsetElem); + for (size_t idx = threadIdx.x; idx < nElem; idx += blockDim.x) { + auto val = SwitchChannelDeviceHandle::multimemLoadReduce(srcElem + idx); + SwitchChannelDeviceHandle::multimemStore(val, dstElem + idx); + } + } else { + // handle data in 16-byte unit + using Type16 = mscclpp::VectorType; + const size_t nType16 = size / sizeof(Type16); + const size_t srcOffset16 = srcOffset / sizeof(Type16); + const size_t dstOffset16 = dstOffset / sizeof(Type16); + Type16* src16 = reinterpret_cast(src) + srcOffset16; + Type16* dst16 = reinterpret_cast(dst) + dstOffset16; + for (size_t idx = threadIdx.x; idx < nType16; idx += blockDim.x) { + Type16 val = SwitchChannelDeviceHandle::multimemLoadReduce(src16 + idx); + SwitchChannelDeviceHandle::multimemStore(val, dst16 + idx); + } + // handle rest of data + constexpr int RedBytes = (sizeof(T) == 8) ? 8 : 4; + using TypeRest = mscclpp::VectorType; + const size_t processed = nType16 * sizeof(Type16); + const size_t nRest = (size - processed) / sizeof(TypeRest); + TypeRest* srcR = reinterpret_cast(src + srcOffset + processed); + TypeRest* dstR = reinterpret_cast(dst + dstOffset + processed); + for (size_t idx = threadIdx.x; idx < nRest; idx += blockDim.x) { + TypeRest val = SwitchChannelDeviceHandle::multimemLoadReduce(srcR + idx); + SwitchChannelDeviceHandle::multimemStore(val, dstR + idx); + } } } } @@ -894,6 +899,17 @@ class ExecutionKernel { #endif break; #endif // __FP8_TYPES_EXIST__ + case DataType::UINT8: + executionKernel<<>>( + rank, (uint8_t*)src, (uint8_t*)dst, (uint8_t*)scratch, scratchOffset, scratchChunkSize, plan, semaphores, + localMemoryIdBegin, flag +#if defined(ENABLE_NPKIT) + , + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); +#else + ); +#endif + break; } } #else // !defined(MSCCLPP_DEVICE_HIP) diff --git a/src/core/include/reduce_kernel.hpp b/src/core/include/reduce_kernel.hpp index 00dc7714..fd9bd1e9 100644 --- a/src/core/include/reduce_kernel.hpp +++ b/src/core/include/reduce_kernel.hpp @@ -60,16 +60,18 @@ MSCCLPP_DEVICE_INLINE DataType cal_vector(const DataType& a, const DataType& b) static_assert(sizeof(DataType) >= 4, "DataType size must be at least 4 bytes"); using CompType = typename std::conditional_t< std::is_same_v, f16x2, - std::conditional_t, bf16x2, + std::conditional_t< + std::is_same_v, bf16x2, + std::conditional_t, u8x4, #if defined(__FP8_TYPES_EXIST__) - std::conditional_t, f8_e4m3x4, - std::conditional_t, f8_e5m2x4, + std::conditional_t, f8_e4m3x4, + std::conditional_t, f8_e5m2x4, #endif - T + T #if defined(__FP8_TYPES_EXIST__) - >>>>; + >>>>>; #else - >>; + >>>; #endif return cal_vector_helper(a, b); } diff --git a/src/ext/collectives/allreduce/allreduce_nvls.cu b/src/ext/collectives/allreduce/allreduce_nvls.cu index b07993a0..b73e1d27 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls.cu @@ -72,8 +72,12 @@ struct NvlsAdapter { mscclpp::DeviceHandle* nvlsOutChannels, size_t channelInOffset, size_t channelOutOffset, size_t, int rank, int nRanksPerNode, int, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { + // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) + if constexpr (std::is_same_v) { + return cudaErrorNotSupported; + } else #if (!defined(__CUDA_ARCH_SPECIFIC__) && !defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000) - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v) { return cudaErrorNotSupported; } else #endif diff --git a/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu index 033f3311..23d5ca4e 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_with_copy.cu @@ -114,18 +114,22 @@ struct NvlsWithCopyAdapter { DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { -#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS - if constexpr (std::is_same_v || std::is_same_v) { + // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) + if constexpr (std::is_same_v) { return cudaErrorNotSupported; } else +#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS + if constexpr (std::is_same_v || std::is_same_v) { + return cudaErrorNotSupported; + } else #endif - { - using ChannelType = DeviceHandle; - allreduce10<<>>(input, scratch, output, (ChannelType*)memoryChannels, - nvlsChannels, inputSize, scratchBufferSize, rank, - nRanksPerNode); - return cudaGetLastError(); - } + { + using ChannelType = DeviceHandle; + allreduce10<<>>(input, scratch, output, (ChannelType*)memoryChannels, + nvlsChannels, inputSize, scratchBufferSize, rank, + nRanksPerNode); + return cudaGetLastError(); + } } }; diff --git a/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu b/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu index 96aa9168..1d8a3478 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_with_copy_2.cu @@ -151,18 +151,22 @@ struct NvlsWithCopy2Adapter { DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { -#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS - if constexpr (std::is_same_v || std::is_same_v) { + // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) + if constexpr (std::is_same_v) { return cudaErrorNotSupported; } else +#if defined(__CUDA_ARCH__) // Skip the __CUDA_ARCH__ < 1000 since FP8 has not been supported for NVLS + if constexpr (std::is_same_v || std::is_same_v) { + return cudaErrorNotSupported; + } else #endif - { - using ChannelType = DeviceHandle; - allreduceNvlsWithCopy2 - <<>>(input, scratch, output, (ChannelType*)memoryChannels, nvlsChannels, - inputSize, scratchBufferSize, rank, nRanksPerNode); - return cudaGetLastError(); - } + { + using ChannelType = DeviceHandle; + allreduceNvlsWithCopy2 + <<>>(input, scratch, output, (ChannelType*)memoryChannels, + nvlsChannels, inputSize, scratchBufferSize, rank, nRanksPerNode); + return cudaGetLastError(); + } } }; diff --git a/src/ext/collectives/include/allreduce/common.hpp b/src/ext/collectives/include/allreduce/common.hpp index 4c28a24a..ab82417a 100644 --- a/src/ext/collectives/include/allreduce/common.hpp +++ b/src/ext/collectives/include/allreduce/common.hpp @@ -96,6 +96,8 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) { #endif } else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) { return Adapter::call; + } else if (dtype == mscclpp::DataType::UINT8) { + return Adapter::call; } else { return nullptr; } @@ -116,6 +118,8 @@ AllreduceFunc dispatch(ReduceOp op, mscclpp::DataType dtype) { #endif } else if (dtype == mscclpp::DataType::INT32 || dtype == mscclpp::DataType::UINT32) { return Adapter::call; + } else if (dtype == mscclpp::DataType::UINT8) { + return Adapter::call; } else { return nullptr; } diff --git a/src/ext/nccl/algorithm_selector.cc b/src/ext/nccl/algorithm_selector.cc index 37126d5b..0e3b3cc1 100644 --- a/src/ext/nccl/algorithm_selector.cc +++ b/src/ext/nccl/algorithm_selector.cc @@ -13,6 +13,12 @@ namespace nccl { static bool isNvlsSupportedForDataType(const AlgorithmSelectorConfig& config, DataType dtype) { bool nvlsSupported = config.nvlsSupported; + + // NVLS does not support uint8_t (no hardware support for byte-level reduction) + if (dtype == DataType::UINT8) { + return false; + } + const bool isFp8 = dtype == DataType::FP8_E4M3 || dtype == DataType::FP8_E5M2; if (!isFp8) { diff --git a/src/ext/nccl/datatype_conversion.hpp b/src/ext/nccl/datatype_conversion.hpp index 8dfe6aab..bb315894 100644 --- a/src/ext/nccl/datatype_conversion.hpp +++ b/src/ext/nccl/datatype_conversion.hpp @@ -18,6 +18,8 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) { return mscclpp::DataType::INT32; case ncclUint32: return mscclpp::DataType::UINT32; + case ncclUint8: + return mscclpp::DataType::UINT8; case ncclFloat16: return mscclpp::DataType::FLOAT16; case ncclFloat32: @@ -38,6 +40,7 @@ inline mscclpp::DataType ncclDataTypeToMscclpp(ncclDataType_t dtype) { // Get the size in bytes of a data type inline size_t getDataTypeSize(mscclpp::DataType dtype) { switch (dtype) { + case mscclpp::DataType::UINT8: case mscclpp::DataType::FP8_E4M3: case mscclpp::DataType::FP8_E5M2: return 1; @@ -59,6 +62,8 @@ static inline ncclDataType_t mscclppToNcclDataType(mscclpp::DataType dtype) { return ncclInt32; case mscclpp::DataType::UINT32: return ncclUint32; + case mscclpp::DataType::UINT8: + return ncclUint8; case mscclpp::DataType::FLOAT16: return ncclFloat16; case mscclpp::DataType::FLOAT32: