diff --git a/CMakeLists.txt b/CMakeLists.txt index a95a8e53..6405511b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -108,7 +108,7 @@ target_include_directories(mscclpp_obj SYSTEM PRIVATE ${GPU_INCLUDE_DIRS} ${NUMA_INCLUDE_DIRS}) -target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads) +target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads dl) if(IBVERBS_FOUND) target_include_directories(mscclpp_obj SYSTEM PRIVATE ${IBVERBS_INCLUDE_DIRS}) target_link_libraries(mscclpp_obj PRIVATE ${IBVERBS_LIBRARIES}) diff --git a/src/include/execution_kernel.hpp b/src/include/execution_kernel.hpp index 21ea80bf..bfc2cebe 100644 --- a/src/include/execution_kernel.hpp +++ b/src/include/execution_kernel.hpp @@ -40,6 +40,16 @@ MSCCLPP_DEVICE_INLINE __half2 add_elements(__half2 a, __half2 b) { return __hadd2(a, b); } +template <> +MSCCLPP_DEVICE_INLINE __bfloat16 add_elements(__bfloat16 a, __bfloat16 b) { + return __hadd(a, b); +} + +template <> +MSCCLPP_DEVICE_INLINE __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) { + return __hadd2(a, b); +} + template MSCCLPP_DEVICE_INLINE int4 add_vectors_helper(int4 a, int4 b) { int4 ret; @@ -239,7 +249,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf T tmp = input[idx]; for (int index = 0; index < nSrcChannels; ++index) { size_t srcOffset = srcOffsets[index] / sizeof(T); - tmp += smChannels[srcChannelIndexes[index]].read(srcOffset + idx); + tmp = add_elements(tmp, smChannels[srcChannelIndexes[index]].read(srcOffset + idx)); } output[idx] = tmp; if (sendToRemote) { @@ -360,7 +370,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T T tmp = src[idx]; for (int index = 0; index < nOutChannels; ++index) { size_t offset = inputOffsets[index] / sizeof(T); - tmp += input[offset + idx]; + tmp = add_elements(tmp, input[offset + idx]); } dst[idx] = tmp; for (int index = 0; index < nOutChannels; ++index) {