mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-12 09:17:06 +00:00
Fix for ROCm 6.0 (#347)
This commit is contained in:
@@ -108,7 +108,7 @@ target_include_directories(mscclpp_obj
|
||||
SYSTEM PRIVATE
|
||||
${GPU_INCLUDE_DIRS}
|
||||
${NUMA_INCLUDE_DIRS})
|
||||
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads)
|
||||
target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} nlohmann_json::nlohmann_json Threads::Threads dl)
|
||||
if(IBVERBS_FOUND)
|
||||
target_include_directories(mscclpp_obj SYSTEM PRIVATE ${IBVERBS_INCLUDE_DIRS})
|
||||
target_link_libraries(mscclpp_obj PRIVATE ${IBVERBS_LIBRARIES})
|
||||
|
||||
@@ -40,6 +40,16 @@ MSCCLPP_DEVICE_INLINE __half2 add_elements(__half2 a, __half2 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE __bfloat16 add_elements(__bfloat16 a, __bfloat16 b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
MSCCLPP_DEVICE_INLINE __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
|
||||
return __hadd2(a, b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
MSCCLPP_DEVICE_INLINE int4 add_vectors_helper(int4 a, int4 b) {
|
||||
int4 ret;
|
||||
@@ -239,7 +249,7 @@ MSCCLPP_DEVICE_INLINE void handleReadReduceCopySend(T* output, uint32_t outputOf
|
||||
T tmp = input[idx];
|
||||
for (int index = 0; index < nSrcChannels; ++index) {
|
||||
size_t srcOffset = srcOffsets[index] / sizeof(T);
|
||||
tmp += smChannels[srcChannelIndexes[index]].read<T>(srcOffset + idx);
|
||||
tmp = add_elements(tmp, smChannels[srcChannelIndexes[index]].read<T>(srcOffset + idx));
|
||||
}
|
||||
output[idx] = tmp;
|
||||
if (sendToRemote) {
|
||||
@@ -360,7 +370,7 @@ MSCCLPP_DEVICE_INLINE void handleReduceSend(T* dst, uint32_t dstOffsetByBytes, T
|
||||
T tmp = src[idx];
|
||||
for (int index = 0; index < nOutChannels; ++index) {
|
||||
size_t offset = inputOffsets[index] / sizeof(T);
|
||||
tmp += input[offset + idx];
|
||||
tmp = add_elements(tmp, input[offset + idx]);
|
||||
}
|
||||
dst[idx] = tmp;
|
||||
for (int index = 0; index < nOutChannels; ++index) {
|
||||
|
||||
Reference in New Issue
Block a user