Mirror of https://github.com/microsoft/mscclpp.git (synced 2026-05-12 01:10:22 +00:00)
Fix & improve perf for ROCm (#232)
Co-authored-by: Binyang Li <binyli@microsoft.com>
@@ -108,6 +108,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/AddFormatTargets.cmake)
 # Find ibverbs and libnuma
 find_package(IBVerbs REQUIRED)
 find_package(NUMA REQUIRED)
+find_package(Threads REQUIRED)
 
 add_library(mscclpp_obj OBJECT)
 target_include_directories(mscclpp_obj
@@ -115,7 +116,7 @@ target_include_directories(mscclpp_obj
     ${GPU_INCLUDE_DIRS}
     ${IBVERBS_INCLUDE_DIRS}
     ${NUMA_INCLUDE_DIRS})
-target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES})
+target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
 set_target_properties(mscclpp_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
 if(USE_CUDA)
   target_compile_definitions(mscclpp_obj PRIVATE USE_CUDA)
@@ -4,20 +4,20 @@
 #ifndef MSCCLPP_DEVICE_HPP_
 #define MSCCLPP_DEVICE_HPP_
 
-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
 #include <hip/hip_runtime.h>
-#endif  // defined(__HIP_PLATFORM_AMD__)
+#endif  // defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
 
 #if (defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))
 
 #define MSCCLPP_DEVICE_COMPILE
 #define MSCCLPP_DEVICE_INLINE __forceinline__ __device__
 #define MSCCLPP_HOST_DEVICE_INLINE __forceinline__ __host__ __device__
-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
 #define MSCCLPP_DEVICE_HIP
-#else   // !defined(__HIP_PLATFORM_AMD__)
+#else   // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
 #define MSCCLPP_DEVICE_CUDA
-#endif  // !defined(__HIP_PLATFORM_AMD__)
+#endif  // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
 
 #else  // !(defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))
 
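Note: the stricter guard `defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)` only takes the HIP path when the macro is defined and set to 1, so a toolchain that defines the macro to 0 no longer selects the AMD branch by accident. A minimal sketch of how downstream code typically consumes these macros; `laneId` is a hypothetical helper, not part of this commit:

    // Hypothetical usage sketch (not part of this commit) of the macros above.
    #include <mscclpp/device.hpp>

    #if defined(MSCCLPP_DEVICE_COMPILE)  // only present when compiling device code
    MSCCLPP_DEVICE_INLINE int laneId() {
    #if defined(MSCCLPP_DEVICE_HIP)
      return __lane_id();       // HIP intrinsic; AMD wavefronts are 64 wide
    #else                       // MSCCLPP_DEVICE_CUDA
      return threadIdx.x & 31;  // 32-wide warps; assumes a 1-D thread block
    #endif
    }
    #endif  // defined(MSCCLPP_DEVICE_COMPILE)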
@@ -70,9 +70,9 @@ struct FifoDeviceHandle {
 #if defined(MSCCLPP_DEVICE_CUDA)
     asm volatile("st.global.relaxed.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
 #else  // !defined(MSCCLPP_DEVICE_CUDA)
-    // TODO: both atomic and clang built-ins are buggy here
-    triggerPtr->fst = trigger.fst;
-    triggerPtr->snd = trigger.snd;
+    // store snd no later than fst.
+    atomicStore(&(triggerPtr->snd), trigger.snd, memoryOrderRelaxed);
+    atomicStore(&(triggerPtr->fst), trigger.fst, memoryOrderRelaxed);
 #endif  // !defined(MSCCLPP_DEVICE_CUDA)
 
     return curFifoHead;
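Note: the plain stores gave the compiler and memory system license to reorder the two 64-bit writes; the relaxed atomicStores are issued snd-first so that, per the new comment, `snd` is stored no later than `fst`, roughly matching what the CUDA path gets from its single 128-bit store. A sketch of the consumer this ordering protects; `Trigger`, `handleTrigger`, and `pollOnce` are illustrative names, not the actual proxy code:

    // Illustrative host-side consumer (hypothetical names, not the real proxy
    // loop): a non-zero fst means "trigger ready", so snd must already be
    // visible by the time fst is observed.
    #include <cstdint>

    struct Trigger { uint64_t fst, snd; };

    void handleTrigger(uint64_t fst, uint64_t snd);  // hypothetical handler

    void pollOnce(volatile Trigger* t) {
      uint64_t f = t->fst;    // poll fst first
      if (f != 0) {
        uint64_t s = t->snd;  // safe only if snd landed no later than fst
        handleTrigger(f, s);
        t->fst = 0;           // release the FIFO slot back to the device
      }
    }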
@@ -4,7 +4,7 @@
 #ifndef MSCCLPP_GPU_HPP_
 #define MSCCLPP_GPU_HPP_
 
-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
 
 #include <hip/hip_runtime.h>
 
@@ -41,7 +41,6 @@ union alignas(16) LLPacket {
 #else  // !defined(MSCCLPP_DEVICE_CUDA)
     uint4 reg = make_uint4(val1, flag, val2, flag);
     ulonglong2* p = reinterpret_cast<ulonglong2*>(&reg);
-    // TODO: clang built-ins are buggy here
     atomicStore(&(raw_.x), p->x, memoryOrderRelaxed);
     atomicStore(&(raw_.y), p->y, memoryOrderRelaxed);
 #endif
@@ -65,7 +64,6 @@ union alignas(16) LLPacket {
     return (flag1 != flag) || (flag2 != flag);
 #else  // !defined(MSCCLPP_DEVICE_CUDA)
     ulonglong2 reg;
-    // TODO: clang built-ins are buggy here
     reg.x = atomicLoad(&(raw_.x), memoryOrderRelaxed);
     reg.y = atomicLoad(&(raw_.y), memoryOrderRelaxed);
     uint4* ptr = reinterpret_cast<uint4*>(&reg);
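Note: the dropped `// TODO: clang built-ins are buggy here` comments were stale; the surrounding `atomicStore`/`atomicLoad` calls stay. The LL packet interleaves payload and flag words so a reader can tell when a complete 16-byte packet has arrived without a separate fence. A standalone sketch of that validity check (illustrative struct, not the real union):

    #include <cstdint>

    // Field order mirrors make_uint4(val1, flag, val2, flag) above.
    struct LL { uint32_t val1, flag1, val2, flag2; };

    // The packet is still in flight while either flag word differs from the
    // current epoch's flag value (same predicate as the device code).
    bool notReady(const LL& p, uint32_t flag) {
      return (p.flag1 != flag) || (p.flag2 != flag);
    }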
@@ -17,7 +17,7 @@ struct Timer {
 
   ~Timer();
 
-  /// Returns the elapsed time in milliseconds.
+  /// Returns the elapsed time in microseconds.
   int64_t elapsed() const;
 
   void set(int timeout);
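Note: this hunk only corrects the doc comment; the implementation is unchanged, so `elapsed()` has always counted microseconds. A small sketch (assuming `mscclpp::Timer` starts timing at construction, which this diff does not show) of converting at the call site:

    #include "timer.hpp"  // assumed include path for the Timer above

    double timedRegionMs() {
      mscclpp::Timer timer;             // assumption: the ctor starts the clock
      // ... region being measured ...
      return timer.elapsed() / 1000.0;  // microseconds -> milliseconds
    }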
@@ -3,7 +3,7 @@
 
 find_package(MPI)
 
-set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES})
+set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
 set(TEST_LIBS_GTEST GTest::gtest_main GTest::gmock_main)
 set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include ${GPU_INCLUDE_DIRS})
 set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include)
@@ -74,7 +74,7 @@ __device__ void localAllGather(DeviceHandle<mscclpp::SimpleProxyChannel> proxyCh
     if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
      if ((threadIdx.x % 32) == 0) proxyChan.wait();
    }
-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
    // NOTE: we actually need a group barrier here for better performance, but __syncthreads() is still correct.
    __syncthreads();
 #else
@@ -371,6 +371,68 @@ __device__ void localReduceScatterSm2(int* buff, int rank, int nRanksPerNode, si
   }
 }
 
+__device__ void localReduceScatterSm3(int* buff, int rank, int nRanksPerNode, size_t chunkSize, size_t nelems,
+                                      int nBlocks) {
+  if (nRanksPerNode == 1) return;
+  if ((int)blockIdx.x >= nBlocks) return;
+  const int nPeer = nRanksPerNode - 1;
+  DeviceHandle<mscclpp::SmChannel>* smChans = constSmOutOfPlaceGetChans;
+
+  const size_t localRankIndexInNode = rank % nRanksPerNode;
+  const size_t indexOffset = localRankIndexInNode * chunkSize;
+  const size_t indexOffset4 = indexOffset / 4;
+
+  int4* buff4 = (int4*)buff;
+
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < nPeer) {
+    smChans[tid].signal();
+  }
+  const int waitStart = nBlocks * blockDim.x - nPeer;
+  if (tid >= waitStart && tid < (int)(nBlocks * blockDim.x)) {
+    smChans[tid - waitStart].wait();
+  }
+  reduceScatterDeviceSyncer.sync(nBlocks);
+
+  const size_t nInt4 = nelems / 4;
+
+  size_t base = 0;
+  const size_t unitNInt4 = blockDim.x * nBlocks;
+  for (; base + unitNInt4 < nInt4; base += unitNInt4) {
+    for (int index = 0; index < nPeer; ++index) {
+      int4 val;
+      int peerIdx = (index + localRankIndexInNode) % nPeer;
+      for (size_t idx = base + threadIdx.x + blockIdx.x * blockDim.x; idx < base + unitNInt4;
+           idx += blockDim.x * nBlocks) {
+        val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
+        buff4[indexOffset4 + idx].w += val.w;
+        buff4[indexOffset4 + idx].x += val.x;
+        buff4[indexOffset4 + idx].y += val.y;
+        buff4[indexOffset4 + idx].z += val.z;
+      }
+    }
+  }
+  for (int index = 0; index < nPeer; ++index) {
+    int4 val;
+    int peerIdx = (index + localRankIndexInNode) % nPeer;
+    for (size_t idx = base + threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * nBlocks) {
+      val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
+      buff4[indexOffset4 + idx].w += val.w;
+      buff4[indexOffset4 + idx].x += val.x;
+      buff4[indexOffset4 + idx].y += val.y;
+      buff4[indexOffset4 + idx].z += val.z;
+    }
+  }
+
+  const size_t nLastInts = nelems % 4;
+  for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+    for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nLastInts; idx += blockDim.x * nBlocks) {
+      int val = smChans[(localRankIndexInNode + peerIdx) % nPeer].read<int>(indexOffset + nInt4 * 4 + idx);
+      buff[indexOffset + nInt4 * 4 + idx] += val;
+    }
+  }
+}
+
 __device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPerNode, int worldSize,
                                 size_t nelems  // must be divisible by 3
 ) {
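Note: `localReduceScatterSm3` staggers its peer loop with `peerIdx = (index + localRankIndexInNode) % nPeer`, so concurrent ranks start reading from different peers instead of all hitting peer 0 at once, and it moves data as `int4` so each thread issues 16-byte accesses. A standalone C++ sketch (not from the commit) printing the resulting peer schedule for 8 ranks per node:

    #include <cstdio>

    int main() {
      const int nRanksPerNode = 8, nPeer = nRanksPerNode - 1;
      for (int rank = 0; rank < nRanksPerNode; ++rank) {
        int localRankIndexInNode = rank % nRanksPerNode;
        std::printf("rank %d reads peers:", rank);
        for (int index = 0; index < nPeer; ++index)
          std::printf(" %d", (index + localRankIndexInNode) % nPeer);
        std::printf("\n");  // each rank begins at a different peer index
      }
      return 0;
    }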
@@ -520,6 +582,39 @@ __device__ void localRingAllGatherSm(int rank, int nRanksPerNode, uint64_t size,
   }
 }
 
+__device__ void localRingAllGatherSm2(size_t rank, size_t nRanksPerNode, size_t size, size_t nBlocks) {
+  if (nRanksPerNode == 1) return;
+  if (blockIdx.x >= nBlocks) return;
+
+  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const size_t nPeer = nRanksPerNode - 1;
+
+  if (tid < nPeer) {
+    constSmInPlaceChans[tid].signal();
+  }
+  size_t waitStart = nBlocks * blockDim.x - nPeer;
+  if (tid >= waitStart && tid < nBlocks * blockDim.x) {
+    constSmInPlaceChans[tid - waitStart].wait();
+  }
+  allGatherDeviceSyncer.sync(nBlocks);
+  const size_t unitSize = 16 * blockDim.x * nBlocks;
+  size_t base = 0;
+  for (; base + unitSize < size; base += unitSize) {
+    for (size_t i = 0; i < nPeer; ++i) {
+      size_t peerIdx = (i + rank) % nPeer;
+      const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+      size_t offset = size * remoteRankLocalIndex + base;
+      constSmInPlaceChans[peerIdx].get(offset, unitSize, tid, blockDim.x * nBlocks);
+    }
+  }
+  for (size_t i = 0; i < nPeer; ++i) {
+    size_t peerIdx = (i + rank) % nPeer;
+    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+    size_t offset = size * remoteRankLocalIndex + base;
+    constSmInPlaceChans[peerIdx].get(offset, size - base, tid, blockDim.x * nBlocks);
+  }
+}
+
 // This is an allgather4 equivalent
 __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
   // this allgather is a pipelined and hierarchical one and only works for two nodes
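Note: the channel array holds one entry per peer and skips the local rank, so `localRingAllGatherSm2` converts a peer index back to the peer's local rank with `peerIdx < rank ? peerIdx : peerIdx + 1` before computing the `get()` offset `size * remoteRankLocalIndex`. A standalone C++ sketch (not from the commit) of that mapping for rank 2 of 4:

    #include <cstdio>

    int main() {
      const size_t rank = 2, nRanksPerNode = 4, nPeer = nRanksPerNode - 1;
      for (size_t peerIdx = 0; peerIdx < nPeer; ++peerIdx) {
        size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
        std::printf("peerIdx %zu -> local rank %zu (offset %zu * size)\n",
                    peerIdx, remoteRankLocalIndex, remoteRankLocalIndex);
      }
      return 0;  // prints local ranks 0, 1, 3: every rank except 2
    }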
@@ -861,9 +956,15 @@ __global__ void allreduce4(int* buff, int* scratch, int rank, int nRanksPerNode,
 }
 
 __global__ void allreduce5(int* buff, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+  localReduceScatterSm3(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
+  deviceSyncer.sync(gridDim.x);
+  localRingAllGatherSm2(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
+#else
   localReduceScatterSm2(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
   deviceSyncer.sync(gridDim.x);
   localRingAllGatherSm(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
+#endif
 }
 
 __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize,
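Note: `allreduce5` keeps the classic two-phase decomposition (reduce-scatter, grid-wide sync, in-place all-gather over each rank's 1/worldSize chunk) and, on ROCm, now dispatches to the new `localReduceScatterSm3` / `localRingAllGatherSm2` variants. A standalone CPU reference (not from the commit) of the same decomposition, assuming the element count divides evenly by the world size:

    #include <cstddef>
    #include <vector>

    // Reference allreduce over per-rank buffers: after the call, every
    // bufs[r] holds the elementwise sum across all ranks.
    void allreduceRef(std::vector<std::vector<int>>& bufs) {
      const size_t world = bufs.size(), chunk = bufs[0].size() / world;
      // Phase 1: reduce-scatter. Rank r ends up owning the sum of chunk r;
      // each rank touches a disjoint index range, so order does not matter.
      for (size_t r = 0; r < world; ++r)
        for (size_t i = r * chunk; i < (r + 1) * chunk; ++i)
          for (size_t peer = 0; peer < world; ++peer)
            if (peer != r) bufs[r][i] += bufs[peer][i];
      // Phase 2: all-gather. Every rank copies each owner's reduced chunk.
      for (size_t r = 0; r < world; ++r)
        for (size_t owner = 0; owner < world; ++owner)
          if (owner != r)
            for (size_t i = owner * chunk; i < (owner + 1) * chunk; ++i)
              bufs[r][i] = bufs[owner][i];
    }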