Fix for ROCm 6.0 (#347)

This commit is contained in:
Changho Hwang
2024-09-01 20:22:33 -07:00
committed by GitHub
parent 4eca6f1e95
commit 72b99a4229
2 changed files with 13 additions and 3 deletions

View File

@@ -108,7 +108,7 @@ target_include_directories(mscclpp_obj
SYSTEM PRIVATE
${GPU_INCLUDE_DIRS}
${NUMA_INCLUDE_DIRS})
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads)
# Link mscclpp_obj against its GPU, NUMA, JSON and threading dependencies.
# ${CMAKE_DL_LIBS} is CMake's portable name for the library that provides
# dlopen/dlclose; it replaces a hard-coded `dl`.
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads ${CMAKE_DL_LIBS})
if(IBVERBS_FOUND)
target_include_directories(mscclpp_obj SYSTEM PRIVATE ${IBVERBS_INCLUDE_DIRS})
target_link_libraries(mscclpp_obj PRIVATE ${IBVERBS_LIBRARIES})

View File

@@ -40,6 +40,16 @@ MSCCLPP_DEVICE_INLINE __half2 add_elements(__half2 a, __half2 b) {
return __hadd2(a, b);
}
// Explicit specialization of add_elements for __bfloat16 operands.
// The sum is computed by the __hadd intrinsic rather than by operator+.
template <>
MSCCLPP_DEVICE_INLINE __bfloat16 add_elements(__bfloat16 a, __bfloat16 b) {
  const __bfloat16 sum = __hadd(a, b);
  return sum;
}
// Explicit specialization of add_elements for __bfloat162 operands.
// The sum is computed by the __hadd2 intrinsic rather than by operator+.
template <>
MSCCLPP_DEVICE_INLINE __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
  const __bfloat162 sum = __hadd2(a, b);
  return sum;
}
template <typename T>
MSCCLPP_DEVICE_INLINE int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
@@ -239,7 +249,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
T tmp = input[idx];
for (int index = 0; index < nSrcChannels; ++index) {
size_t srcOffset = srcOffsets[index] / sizeof(T);
tmp += smChannels[srcChannelIndexes[index]].read<T>(srcOffset + idx);
tmp = add_elements(tmp, smChannels[srcChannelIndexes[index]].read<T>(srcOffset + idx));
}
output[idx] = tmp;
if (sendToRemote) {
@@ -360,7 +370,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
T tmp = src[idx];
for (int index = 0; index < nOutChannels; ++index) {
size_t offset = inputOffsets[index] / sizeof(T);
tmp += input[offset + idx];
tmp = add_elements(tmp, input[offset + idx]);
}
dst[idx] = tmp;
for (int index = 0; index < nOutChannels; ++index) {