Add min operation for allreduce (#481)

Author: Binyang Li (committed by GitHub)
Date: 2025-03-16 20:47:36 -07:00
Parent: 0b840baa05
Commit: f124dc1df9
2 changed files with 273 additions and 18 deletions
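At its core, the change threads a runtime Op value through the allreduce kernels and bridges it to compile-time-specialized reduction helpers via if constexpr. Below is a minimal standalone sketch of that pattern; the names (cal, reduceKernel) and the float demo are illustrative assumptions, not the committed code.

// Sketch: runtime-op to compile-time-op dispatch, as used by the new
// cal_elements / cal_vectors helpers. Build with: nvcc -o sketch sketch.cu
#include <cstdio>

enum Op { SUM = 0, MIN = 3 };

// Compile-time dispatch: resolved per instantiation, so each specialized
// helper carries no op branch in its inner loop.
template <typename T, Op op>
__forceinline__ __device__ T cal_elements(T a, T b) {
  if constexpr (op == SUM) return a + b;
  else return a < b ? a : b;
}

// Runtime-to-compile-time bridge, mirroring cal_vectors(..., Op op) in the
// diff: a single runtime branch picks the instantiation.
template <typename T>
__forceinline__ __device__ T cal(T a, T b, Op op) {
  if (op == SUM) return cal_elements<T, SUM>(a, b);
  return cal_elements<T, MIN>(a, b);
}

template <typename T>
__global__ void reduceKernel(const T* x, const T* y, T* out, int n, Op op) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = cal<T>(x[i], y[i], op);
}

int main() {
  const int n = 4;
  float *x, *y, *out;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; i++) { x[i] = float(i); y[i] = 2.5f - i; }
  reduceKernel<float><<<1, 32>>>(x, y, out, n, MIN);
  cudaDeviceSynchronize();
  for (int i = 0; i < n; i++) printf("%g ", out[i]);  // prints: 0 1 0.5 -0.5
  printf("\n");
  return 0;
}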

File 1 of 2:

@@ -18,6 +18,11 @@
#include "common.hpp"
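// Note: the enum values appear to track ncclRedOp_t (ncclSum = 0, ncclMin = 3).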
enum Op {
SUM = 0,
MIN = 3,
};
template <typename To, typename From>
__forceinline__ __device__ To bit_cast(const From& src) {
static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
@@ -73,16 +78,48 @@ __forceinline__ __device__ T add_elements(T a, T b) {
return clip(a + b);
}
template <typename T>
__forceinline__ __device__ T min_elements(T a, T b) {
return (a < b ? a : b);
}
template <typename T, Op op>
__forceinline__ __device__ T cal_elements(T a, T b) {
  if constexpr (op == SUM) {
    return add_elements(a, b);
  } else if constexpr (op == MIN) {
    return min_elements(a, b);
  }
  // SHOULD NOT REACH HERE
  return a;
}
template <>
__forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
return clip(__hadd2(a, b));
}
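// The AMD branch below takes the minimum per component, presumably because a
// packed __hmin2 for __half2 is unavailable on that HIP path; CUDA uses the
// packed intrinsic directly.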
template <>
__forceinline__ __device__ __half2 min_elements(__half2 a, __half2 b) {
#if defined(__HIP_PLATFORM_AMD__)
__half2 val;
val.x = __hmin(a.x, b.x);
val.y = __hmin(a.y, b.y);
return val;
#else
return __hmin2(a, b);
#endif
}
template <>
__forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
return clip(__hadd2(a, b));
}
template <>
__forceinline__ __device__ __bfloat162 min_elements(__bfloat162 a, __bfloat162 b) {
return __hmin2(a, b);
}
template <typename T>
__forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
int4 ret;
@@ -93,16 +130,57 @@ __forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
return ret;
}
template <typename T>
__forceinline__ __device__ int4 min_vectors_helper(int4 a, int4 b) {
int4 ret;
ret.w = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
ret.x = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
ret.z = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
return ret;
}
template <typename T, Op op>
__forceinline__ __device__ int4 cal_vectors_helper(int4 a, int4 b) {
if constexpr (op == SUM) {
return add_vectors_helper<T>(a, b);
} else if constexpr (op == MIN) {
return min_vectors_helper<T>(a, b);
}
return a;
}
template <typename T>
__forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
return add_vectors_helper<T>(a, b);
}
template <typename T>
__forceinline__ __device__ int4 min_vectors(int4 a, int4 b) {
return min_vectors_helper<T>(a, b);
}
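// The wrappers below bridge the runtime Op argument to the compile-time
// specializations; the same pattern repeats for the uint2, int, and uint32_t
// widths that follow.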
template <typename T>
__forceinline__ __device__ int4 cal_vectors(int4 a, int4 b, Op op) {
if (op == SUM) {
return cal_vectors_helper<T, SUM>(a, b);
} else if (op == MIN) {
return cal_vectors_helper<T, MIN>(a, b);
}
// SHOULD NOT REACH HERE
return a;
}
template <>
__forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
return add_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ int4 min_vectors<__half>(int4 a, int4 b) {
return min_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ int4 add_vectors<__bfloat16>(int4 a, int4 b) {
return add_vectors_helper<__bfloat162>(a, b);
@@ -116,61 +194,226 @@ __forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
return ret;
}
template <typename T>
__forceinline__ __device__ uint2 min_vectors_helper(uint2 a, uint2 b) {
uint2 ret;
ret.x = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
ret.y = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
return ret;
}
template <typename T, Op op>
__forceinline__ __device__ uint2 cal_vectors_helper(uint2 a, uint2 b) {
if constexpr (op == SUM) {
return add_vectors_helper<T>(a, b);
} else if constexpr (op == MIN) {
return min_vectors_helper<T>(a, b);
}
return a;
}
template <typename T>
__forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
return add_vectors_helper<T>(a, b);
}
template <typename T>
__forceinline__ __device__ uint2 min_vectors(uint2 a, uint2 b) {
return min_vectors_helper<T>(a, b);
}
template <typename T>
__forceinline__ __device__ uint2 cal_vectors(uint2 a, uint2 b, Op op) {
if (op == SUM) {
return cal_vectors_helper<T, SUM>(a, b);
} else {
return cal_vectors_helper<T, MIN>(a, b);
}
}
template <>
__forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
return add_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ uint2 min_vectors<__half>(uint2 a, uint2 b) {
return min_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ uint2 cal_vectors<__half>(uint2 a, uint2 b, Op op) {
if (op == SUM) {
return cal_vectors_helper<__half2, SUM>(a, b);
} else {
return cal_vectors_helper<__half2, MIN>(a, b);
}
}
template <>
__forceinline__ __device__ uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <>
__forceinline__ __device__ uint2 min_vectors<__bfloat16>(uint2 a, uint2 b) {
return min_vectors_helper<__bfloat162>(a, b);
}
template <typename T>
__forceinline__ __device__ int add_vectors_helper(int a, int b) {
return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
}
template <typename T>
__forceinline__ __device__ int min_vectors_helper(int a, int b) {
return bit_cast<int, T>(min_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
}
template <typename T, Op op>
__forceinline__ __device__ int cal_vectors_helper(int a, int b) {
if constexpr (op == SUM) {
return add_vectors_helper<T>(a, b);
} else if constexpr (op == MIN) {
return min_vectors_helper<T>(a, b);
}
return a;
}
template <typename T>
__forceinline__ __device__ int add_vectors(int a, int b) {
return add_vectors_helper<T>(a, b);
}
template <typename T>
__forceinline__ __device__ int min_vectors(int a, int b) {
return min_vectors_helper<T>(a, b);
}
template <typename T>
__forceinline__ __device__ int cal_vectors(int a, int b, Op op) {
if (op == SUM) {
return cal_vectors_helper<T, SUM>(a, b);
} else {
return cal_vectors_helper<T, MIN>(a, b);
}
}
template <>
__forceinline__ __device__ int add_vectors<__half>(int a, int b) {
return add_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ int min_vectors<__half>(int a, int b) {
return min_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ int cal_vectors<__half>(int a, int b, Op op) {
if (op == SUM) {
return cal_vectors_helper<__half2, SUM>(a, b);
} else {
return cal_vectors_helper<__half2, MIN>(a, b);
}
}
template <>
__forceinline__ __device__ int add_vectors<__bfloat16>(int a, int b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <>
__forceinline__ __device__ int min_vectors<__bfloat16>(int a, int b) {
return min_vectors_helper<__bfloat162>(a, b);
}
template <>
__forceinline__ __device__ int cal_vectors<__bfloat16>(int a, int b, Op op) {
if (op == SUM) {
return cal_vectors_helper<__bfloat162, SUM>(a, b);
} else {
return cal_vectors_helper<__bfloat162, MIN>(a, b);
}
}
template <typename T>
__forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
}
template <typename T>
__forceinline__ __device__ uint32_t min_vectors_helper(uint32_t a, uint32_t b) {
return bit_cast<uint32_t, T>(min_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
}
template <typename T, Op op>
__forceinline__ __device__ uint32_t cal_vectors_helper(uint32_t a, uint32_t b) {
if constexpr (op == SUM) {
return add_vectors_helper<T>(a, b);
} else if constexpr (op == MIN) {
return min_vectors_helper<T>(a, b);
}
return a;
}
template <typename T>
__forceinline__ __device__ uint32_t add_vectors(uint32_t a, uint32_t b) {
return add_vectors_helper<T>(a, b);
}
template <typename T>
__forceinline__ __device__ uint32_t min_vectors(uint32_t a, uint32_t b) {
return min_vectors_helper<T>(a, b);
}
template <typename T>
__forceinline__ __device__ uint32_t cal_vectors(uint32_t a, uint32_t b, Op op) {
if (op == SUM) {
return cal_vectors_helper<T, SUM>(a, b);
} else {
return cal_vectors_helper<T, MIN>(a, b);
}
}
template <>
__forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
return add_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ uint32_t min_vectors<__half>(uint32_t a, uint32_t b) {
return min_vectors_helper<__half2>(a, b);
}
template <>
__forceinline__ __device__ uint32_t cal_vectors<__half>(uint32_t a, uint32_t b, Op op) {
if (op == SUM) {
return cal_vectors_helper<__half2, SUM>(a, b);
} else {
return cal_vectors_helper<__half2, MIN>(a, b);
}
}
template <>
__forceinline__ __device__ uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
return add_vectors_helper<__bfloat162>(a, b);
}
template <>
__forceinline__ __device__ uint32_t min_vectors<__bfloat16>(uint32_t a, uint32_t b) {
return min_vectors_helper<__bfloat162>(a, b);
}
template <>
__forceinline__ __device__ uint32_t cal_vectors<__bfloat16>(uint32_t a, uint32_t b, Op op) {
if (op == SUM) {
return cal_vectors_helper<__bfloat162, SUM>(a, b);
} else {
return cal_vectors_helper<__bfloat162, MIN>(a, b);
}
}
template <typename T>
__forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) {
size_t nInt4 = nElem / 4;
@@ -198,7 +441,7 @@ template <typename T>
__global__ void __launch_bounds__(32, 1)
allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
- size_t nelems, uint32_t flag) {
+ Op op, size_t nelems, uint32_t flag) {
// This version of allreduce only works for single nodes
if (worldSize != nRanksPerNode) return;
if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
@@ -226,14 +469,13 @@ __global__ void __launch_bounds__(32, 1)
// step 2: Reduce Data
for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
- uint32_t data = 0;
+ uint32_t data = src[idx];
for (int index = 0; index < nPeers; index++) {
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nelems;
uint32_t val = dstPkt[idx].read(flag, -1);
- data = add_vectors<T>(val, data);
+ data = cal_vectors<T>(val, data, op);
}
- data = add_vectors<T>(data, src[idx]);
dst[idx] = data;
}
}
@@ -241,7 +483,7 @@ __global__ void __launch_bounds__(32, 1)
template <typename T>
__global__ void __launch_bounds__(1024, 1)
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
- size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
+ size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, Op op,
size_t nelems, uint32_t flag
#if defined(ENABLE_NPKIT)
,
@@ -321,8 +563,8 @@ __global__ void __launch_bounds__(1024, 1)
const int remoteRank = index < rank ? index : index + 1;
mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
uint2 val = dstPkt[idx].read(flag);
- data.x = add_vectors<T>(val.x, data.x);
- data.y = add_vectors<T>(val.y, data.y);
+ data.x = cal_vectors<T>(val.x, data.x, op);
+ data.y = cal_vectors<T>(val.y, data.y, op);
}
dst[idx].x = data.x;
@@ -499,7 +741,7 @@ template <typename T>
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
- size_t nelems, cudaStream_t stream) {
+ Op op, size_t nelems, cudaStream_t stream) {
static uint32_t flag = 1;
if (sizeof(T) * nelems < worldSize * sizeof(int)) {
@@ -507,7 +749,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
int nThreadsPerBlock = 32;
allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels,
channelInOffset, channelScratchOffset, rank,
- nRanksPerNode, worldSize, nelems, flag++);
+ nRanksPerNode, worldSize, op, nelems, flag++);
} else if (sizeof(T) * nelems <= (1 << 20)) {
int nBlocks = 28;
int nThreadsPerBlock = 1024;
@@ -519,11 +761,11 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
allreduce7<<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
buff, scratch, resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode,
- worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
+ worldSize, op, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, channelInOffset,
- channelScratchOffset, rank, nRanksPerNode, worldSize, nelems,
- flag++);
+ channelScratchOffset, rank, nRanksPerNode, worldSize, op,
+ nelems, flag++);
#endif
} else {
int nBlocks = 35;

File 2 of 2:

@@ -138,6 +138,18 @@ static mscclpp::Transport getTransport(int, int) {
return mscclpp::Transport::CudaIpc;
}
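// Maps the NCCL reduction op onto the kernel-side Op enum; anything other
// than sum or min is rejected.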
static Op getReduceOp(ncclRedOp_t op) {
switch (op) {
case ncclSum:
return SUM;
case ncclMin:
return MIN;
default:
WARN("op is invalid, op: %d", op);
throw mscclpp::Error("Invalid operation", mscclpp::ErrorCode::InternalError);
}
}
static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
void* buff, size_t bytes,
mscclpp::TransportFlags transport) {
@@ -268,27 +280,28 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
memoryOutChannels = recvIt->second.memoryChannelDeviceHandles.get();
}
+ Op reduceOp = getReduceOp(op);
switch (datatype) {
case ncclFloat16:
CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, memoryChannels,
memoryOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
- comm->comm->bootstrap()->getNranks(), count, stream));
+ comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
break;
case ncclFloat32:
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, memoryChannels,
memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
break;
case ncclBfloat16:
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
memoryChannels, memoryOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
- comm->comm->bootstrap()->getNranks(), count, stream));
+ comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, memoryChannels,
memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
- NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+ NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
break;
default:
WARN("datatype is invalid, datatype: %d", datatype);
@@ -955,10 +968,10 @@ ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
return ncclInternalError;
}
} catch (const mscclpp::CudaError& e) {
INFO(MSCCLPP_ALLOC, "Cuda error: %s", e.what());
WARN("Cuda error: %s", e.what());
return ncclUnhandledCudaError;
} catch (const mscclpp::CuError& e) {
INFO(MSCCLPP_ALLOC, "Cu error: %s", e.what());
WARN("Cu error: %s", e.what());
return ncclUnhandledCudaError;
} catch (const mscclpp::BaseError& e) {
WARN("Base error: %s", e.what());