mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-04-19 22:39:11 +00:00
bfloat16 support (#336)
* Add bfloat16 support for executor and NCCL interface * Changed `gpu_data_types.hpp` into an internal header file
This commit is contained in:
@@ -11,7 +11,7 @@ endif()
|
||||
add_library(mscclpp_nccl_obj OBJECT)
|
||||
target_sources(mscclpp_nccl_obj PRIVATE ${SOURCES})
|
||||
target_sources(mscclpp_nccl_obj PUBLIC FILE_SET HEADERS FILES ${HEADERS})
|
||||
target_include_directories(mscclpp_nccl_obj PRIVATE include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
target_include_directories(mscclpp_nccl_obj PRIVATE include ${PROJECT_SOURCE_DIR}/src/include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
|
||||
target_link_libraries(mscclpp_nccl_obj PRIVATE ${GPU_LIBRARIES} PUBLIC mscclpp_obj)
|
||||
set_target_properties(mscclpp_nccl_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
|
||||
if(USE_CUDA)
|
||||
|
||||
@@ -7,12 +7,12 @@
|
||||
#include <mscclpp/concurrency_device.hpp>
|
||||
#include <mscclpp/core.hpp>
|
||||
#include <mscclpp/gpu.hpp>
|
||||
#include <mscclpp/gpu_data_types.hpp>
|
||||
#include <mscclpp/packet_device.hpp>
|
||||
#include <mscclpp/sm_channel.hpp>
|
||||
#include <mscclpp/sm_channel_device.hpp>
|
||||
|
||||
#include "common.hpp"
|
||||
#include "gpu_data_types.hpp"
|
||||
|
||||
__device__ mscclpp::DeviceSyncer deviceSyncer;
|
||||
|
||||
@@ -38,6 +38,11 @@ __forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
|
||||
int4 ret;
|
||||
@@ -58,6 +63,11 @@ __forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int4 add_vectors<__bfloat16>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
|
||||
uint2 ret;
|
||||
@@ -76,6 +86,11 @@ __forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
|
||||
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
|
||||
@@ -91,6 +106,11 @@ __forceinline__ __device__ int add_vectors<__half>(int a, int b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ int add_vectors<__bfloat16>(int a, int b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
|
||||
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
|
||||
@@ -106,6 +126,11 @@ __forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b)
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__forceinline__ __device__ uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) {
|
||||
size_t nInt4 = nElem / 4;
|
||||
|
||||
@@ -262,6 +262,11 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
|
||||
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
|
||||
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
break;
|
||||
case ncclBfloat16:
|
||||
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
|
||||
smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
|
||||
comm->comm->bootstrap()->getNranks(), count, stream));
|
||||
break;
|
||||
case ncclInt32:
|
||||
case ncclUint32:
|
||||
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
|
||||
@@ -498,6 +503,10 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
|
||||
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
|
||||
1024, *plan, stream, mscclpp::PacketType::LL8);
|
||||
break;
|
||||
case ncclBfloat16:
|
||||
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
|
||||
mscclpp::DataType::BFLOAT16, 1024, *plan, stream, mscclpp::PacketType::LL8);
|
||||
break;
|
||||
case ncclInt32:
|
||||
case ncclUint32:
|
||||
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,
|
||||
|
||||
@@ -15,6 +15,7 @@ enum class DataType {
|
||||
UINT32,
|
||||
FLOAT16,
|
||||
FLOAT32,
|
||||
BFLOAT16,
|
||||
};
|
||||
|
||||
enum class PacketType {
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include <type_traits>
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_CUDA)
|
||||
#include <mscclpp/gpu_data_types.hpp>
|
||||
#include <cuda_fp16.h>
|
||||
#endif // defined(MSCCLPP_DEVICE_CUDA)
|
||||
|
||||
#include "device.hpp"
|
||||
|
||||
@@ -16,7 +16,8 @@ void register_executor(nb::module_& m) {
|
||||
.value("int32", DataType::INT32)
|
||||
.value("uint32", DataType::UINT32)
|
||||
.value("float16", DataType::FLOAT16)
|
||||
.value("float32", DataType::FLOAT32);
|
||||
.value("float32", DataType::FLOAT32)
|
||||
.value("bfloat16", DataType::BFLOAT16);
|
||||
|
||||
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);
|
||||
|
||||
|
||||
@@ -49,6 +49,16 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::BFLOAT16:
|
||||
executionKernel<__bfloat16><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "execution_common.hpp"
|
||||
|
||||
#if defined(MSCCLPP_DEVICE_COMPILE)
|
||||
#include <mscclpp/gpu_data_types.hpp>
|
||||
#include "gpu_data_types.hpp"
|
||||
|
||||
namespace {
|
||||
template <typename To, typename From>
|
||||
@@ -60,6 +60,11 @@ MSCCLPP_DEVICE_INLINE int4 add_vectors<__half>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE int4 add_vectors<__bfloat16>(int4 a, int4 b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MSCCLPP_DEVICE_INLINE uint2 add_vectors_helper(uint2 a, uint2 b) {
|
||||
uint2 ret;
|
||||
@@ -78,6 +83,11 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__half>(uint2 a,
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MSCCLPP_DEVICE_INLINE int add_vectors_helper(int a, int b) {
|
||||
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
|
||||
@@ -93,6 +103,11 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__half>(int a, int
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__bfloat16>(int a, int b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MSCCLPP_DEVICE_INLINE uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
|
||||
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
|
||||
@@ -108,6 +123,11 @@ MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
|
||||
return add_vectors_helper<__half2>(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
|
||||
return add_vectors_helper<__bfloat162>(a, b);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
#endif // defined(MSCCLPP_DEVICE_COMPILE)
|
||||
|
||||
@@ -502,6 +522,16 @@ class ExecutionKernel {
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
case DataType::BFLOAT16:
|
||||
executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
|
||||
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
|
||||
#if defined(ENABLE_NPKIT)
|
||||
,
|
||||
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
|
||||
#else
|
||||
);
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
using __bfloat16 = __hip_bfloat16;
|
||||
using __bfloat162 = __hip_bfloat162;
|
||||
#define __CUDA_BF16_TYPES_EXIST__
|
||||
|
||||
#else
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
@@ -19,6 +23,9 @@
|
||||
#include <cuda_fp8.h>
|
||||
#endif
|
||||
|
||||
using __bfloat16 = __nv_bfloat16;
|
||||
using __bfloat162 = __nv_bfloat162;
|
||||
|
||||
#endif
|
||||
|
||||
#endif // MSCCLPP_GPU_DATA_TYPES_HPP_
|
||||
Reference in New Issue
Block a user