bfloat16 support (#336)

* Add bfloat16 support for executor and NCCL interface
* Changed `gpu_data_types.hpp` into an internal header file
This commit is contained in:
Changho Hwang
2024-08-12 15:41:58 -07:00
committed by GitHub
parent faadc75649
commit 8c6fb429e9
9 changed files with 88 additions and 5 deletions

View File

@@ -11,7 +11,7 @@ endif()
add_library(mscclpp_nccl_obj OBJECT)
target_sources(mscclpp_nccl_obj PRIVATE ${SOURCES})
target_sources(mscclpp_nccl_obj PUBLIC FILE_SET HEADERS FILES ${HEADERS})
target_include_directories(mscclpp_nccl_obj PRIVATE include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
target_include_directories(mscclpp_nccl_obj PRIVATE include ${PROJECT_SOURCE_DIR}/src/include SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
target_link_libraries(mscclpp_nccl_obj PRIVATE ${GPU_LIBRARIES} PUBLIC mscclpp_obj)
set_target_properties(mscclpp_nccl_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
if(USE_CUDA)

View File

@@ -7,12 +7,12 @@
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/packet_device.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include "common.hpp"
#include "gpu_data_types.hpp"
__device__ mscclpp::DeviceSyncer deviceSyncer;
@@ -38,6 +38,11 @@ __forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
return __hadd2(a, b);
}
// bfloat16 specialization: element-wise add of two packed __bfloat162 values.
// __hadd2 is overloaded for the bf16 pair type by the vendor bf16 header
// (NOTE(review): native bf16 arithmetic presumably needs SM80+ on CUDA,
// with emulation on older arches — confirm against the toolkit docs).
template <>
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
return __hadd2(a, b);
}
template <typename T>
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
@@ -58,6 +63,11 @@ __forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}
// Reinterpret the 128-bit int4 payload as eight bfloat16 values and add them
// pairwise: the helper is instantiated with __bfloat162 so each 32-bit lane
// is summed as one packed bf16 pair.
template <>
__forceinline__ __device__ int4 add_vectors<__bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
@@ -76,6 +86,11 @@ __forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
return add_vectors_helper<__half2>(a, b);
}
// 64-bit variant: treat the uint2 payload as four bfloat16 values (two
// __bfloat162 pairs) and add them via the packed-pair helper.
template <>
__forceinline__ __device__ uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
@@ -91,6 +106,11 @@ __forceinline__ __device__ int add_vectors<__half>(int a, int b) {
return add_vectors_helper<__half2>(a, b);
}
// 32-bit variant: a single int holds one __bfloat162 pair (two bfloat16
// values); bit-cast and add through the packed-pair helper.
template <>
__forceinline__ __device__ int add_vectors<__bfloat16>(int a, int b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
@@ -106,6 +126,11 @@ __forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b)
return add_vectors_helper<__half2>(a, b);
}
// Unsigned 32-bit variant: same as the int specialization — one packed
// __bfloat162 pair per word, added via the helper.
template <>
__forceinline__ __device__ uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
__forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) {
size_t nInt4 = nElem / 4;

View File

@@ -262,6 +262,11 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclBfloat16:
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
@@ -498,6 +503,10 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
1024, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclBfloat16:
comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
mscclpp::DataType::BFLOAT16, 1024, *plan, stream, mscclpp::PacketType::LL8);
break;
case ncclInt32:
case ncclUint32:
comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, 1024,

View File

@@ -15,6 +15,7 @@ enum class DataType {
UINT32,
FLOAT16,
FLOAT32,
BFLOAT16,
};
enum class PacketType {

View File

@@ -8,7 +8,7 @@
#include <type_traits>
#if defined(MSCCLPP_DEVICE_CUDA)
#include <mscclpp/gpu_data_types.hpp>
#include <cuda_fp16.h>
#endif // defined(MSCCLPP_DEVICE_CUDA)
#include "device.hpp"

View File

@@ -16,7 +16,8 @@ void register_executor(nb::module_& m) {
.value("int32", DataType::INT32)
.value("uint32", DataType::UINT32)
.value("float16", DataType::FLOAT16)
.value("float32", DataType::FLOAT32);
.value("float32", DataType::FLOAT32)
.value("bfloat16", DataType::BFLOAT16);
nb::enum_<PacketType>(m, "PacketType").value("LL8", PacketType::LL8).value("LL16", PacketType::LL16);

View File

@@ -49,6 +49,16 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
}

View File

@@ -15,7 +15,7 @@
#include "execution_common.hpp"
#if defined(MSCCLPP_DEVICE_COMPILE)
#include <mscclpp/gpu_data_types.hpp>
#include "gpu_data_types.hpp"
namespace {
template <typename To, typename From>
@@ -60,6 +60,11 @@ MSCCLPP_DEVICE_INLINE int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}
// Executor-side bfloat16 specialization: the 128-bit int4 payload carries
// eight bf16 values; sum them as four packed __bfloat162 pairs.
template <>
MSCCLPP_DEVICE_INLINE int4 add_vectors<__bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
MSCCLPP_DEVICE_INLINE uint2 add_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
@@ -78,6 +83,11 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__half>(uint2 a,
return add_vectors_helper<__half2>(a, b);
}
// 64-bit executor variant: uint2 carries four bf16 values (two packed
// pairs). Marked unused to silence warnings in TUs that never call it.
template <>
MSCCLPP_DEVICE_INLINE __attribute__((unused)) uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
MSCCLPP_DEVICE_INLINE int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
@@ -93,6 +103,11 @@ MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__half>(int a, int
return add_vectors_helper<__half2>(a, b);
}
// 32-bit executor variant: one __bfloat162 pair per int word. Marked unused
// to silence warnings in TUs that never instantiate this path.
template <>
MSCCLPP_DEVICE_INLINE __attribute__((unused)) int add_vectors<__bfloat16>(int a, int b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
MSCCLPP_DEVICE_INLINE uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
@@ -108,6 +123,11 @@ MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
return add_vectors_helper<__half2>(a, b);
}
// Unsigned 32-bit executor variant: one packed __bfloat162 pair per word,
// added via the shared helper.
template <>
MSCCLPP_DEVICE_INLINE uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
return add_vectors_helper<__bfloat162>(a, b);
}
} // namespace
#endif // defined(MSCCLPP_DEVICE_COMPILE)
@@ -502,6 +522,16 @@ class ExecutionKernel {
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
);
#endif
break;
}

View File

@@ -9,6 +9,10 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using __bfloat16 = __hip_bfloat16;
using __bfloat162 = __hip_bfloat162;
#define __CUDA_BF16_TYPES_EXIST__
#else
#include <cuda_fp16.h>
@@ -19,6 +23,9 @@
#include <cuda_fp8.h>
#endif
using __bfloat16 = __nv_bfloat16;
using __bfloat162 = __nv_bfloat162;
#endif
#endif // MSCCLPP_GPU_DATA_TYPES_HPP_