diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index 3861fa5a..d9f12ce3 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -18,6 +18,11 @@
 #include "common.hpp"
 
+enum Op {
+  SUM = 0,
+  MIN = 3,
+};
+
 template <typename To, typename From>
 __forceinline__ __device__ To bit_cast(const From& src) {
   static_assert(sizeof(To) == sizeof(From), "Size mismatch for bit_cast");
@@ -73,16 +78,48 @@ __forceinline__ __device__ T add_elements(T a, T b) {
   return clip(a + b);
 }
 
+template <typename T>
+__forceinline__ __device__ T min_elements(T a, T b) {
+  return (a < b ? a : b);
+}
+
+template <typename T, Op op>
+__forceinline__ __device__ T cal_elements(T a, T b) {
+  if constexpr (op == SUM) {
+    return add_elements(a, b);
+  } else if constexpr (op == MIN) {
+    return min_elements(a, b);
+  }
+  return (a < b ? a : b);
+}
+
 template <>
 __forceinline__ __device__ __half2 add_elements(__half2 a, __half2 b) {
   return clip(__hadd2(a, b));
 }
 
+template <>
+__forceinline__ __device__ __half2 min_elements(__half2 a, __half2 b) {
+#if defined(__HIP_PLATFORM_AMD__)
+  __half2 val;
+  val.x = __hmin(a.x, b.x);
+  val.y = __hmin(a.y, b.y);
+  return val;
+#else
+  return __hmin2(a, b);
+#endif
+}
+
 template <>
 __forceinline__ __device__ __bfloat162 add_elements(__bfloat162 a, __bfloat162 b) {
   return clip(__hadd2(a, b));
 }
 
+template <>
+__forceinline__ __device__ __bfloat162 min_elements(__bfloat162 a, __bfloat162 b) {
+  return __hmin2(a, b);
+}
+
 template <typename T>
 __forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
   int4 ret;
@@ -93,16 +130,57 @@ __forceinline__ __device__ int4 add_vectors_helper(int4 a, int4 b) {
   return ret;
 }
 
+template <typename T>
+__forceinline__ __device__ int4 min_vectors_helper(int4 a, int4 b) {
+  int4 ret;
+  ret.w = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.w), bit_cast<T, int>(b.w)));
+  ret.x = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.x), bit_cast<T, int>(b.x)));
+  ret.y = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.y), bit_cast<T, int>(b.y)));
+  ret.z = bit_cast<int, T>(min_elements(bit_cast<T, int>(a.z), bit_cast<T, int>(b.z)));
+  return ret;
+}
+
+template <typename T, Op op>
+__forceinline__ __device__ int4 cal_vectors_helper(int4 a, int4 b) {
+  if constexpr (op == SUM) {
+    return add_vectors_helper<T>(a, b);
+  } else if constexpr (op == MIN) {
+    return min_vectors_helper<T>(a, b);
+  }
+  return a;
+}
+
 template <typename T>
 __forceinline__ __device__ int4 add_vectors(int4 a, int4 b) {
   return add_vectors_helper<T>(a, b);
 }
 
+template <typename T>
+__forceinline__ __device__ int4 min_vectors(int4 a, int4 b) {
+  return min_vectors_helper<T>(a, b);
+}
+
+template <typename T>
+__forceinline__ __device__ int4 cal_vectors(int4 a, int4 b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<T, SUM>(a, b);
+  } else if (op == MIN) {
+    return cal_vectors_helper<T, MIN>(a, b);
+  }
+  // SHOULD NOT REACH HERE
+  return a;
+}
+
 template <>
 __forceinline__ __device__ int4 add_vectors<__half>(int4 a, int4 b) {
   return add_vectors_helper<__half2>(a, b);
 }
 
+template <>
+__forceinline__ __device__ int4 min_vectors<__half>(int4 a, int4 b) {
+  return min_vectors_helper<__half2>(a, b);
+}
+
 template <>
 __forceinline__ __device__ int4 add_vectors<__bfloat16>(int4 a, int4 b) {
   return add_vectors_helper<__bfloat162>(a, b);
@@ -116,61 +194,226 @@ __forceinline__ __device__ uint2 add_vectors_helper(uint2 a, uint2 b) {
   return ret;
 }
 
+template <typename T>
+__forceinline__ __device__ uint2 min_vectors_helper(uint2 a, uint2 b) {
+  uint2 ret;
+  ret.x = bit_cast<uint32_t, T>(min_elements(bit_cast<T, uint32_t>(a.x), bit_cast<T, uint32_t>(b.x)));
+  ret.y = bit_cast<uint32_t, T>(min_elements(bit_cast<T, uint32_t>(a.y), bit_cast<T, uint32_t>(b.y)));
+  return ret;
+}
+
+template <typename T, Op op>
+__forceinline__ __device__ uint2 cal_vectors_helper(uint2 a, uint2 b) {
+  if constexpr (op == SUM) {
+    return add_vectors_helper<T>(a, b);
+  } else if constexpr (op == MIN) {
+    return min_vectors_helper<T>(a, b);
+  }
+  return a;
+}
+
 template <typename T>
 __forceinline__ __device__ uint2 add_vectors(uint2 a, uint2 b) {
   return add_vectors_helper<T>(a, b);
 }
 
+template <typename T>
+__forceinline__ __device__ uint2 min_vectors(uint2 a, uint2 b) {
+  return min_vectors_helper<T>(a, b);
+}
+
+template <typename T>
+__forceinline__ __device__ uint2 cal_vectors(uint2 a, uint2 b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<T, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<T, MIN>(a, b);
+  }
+}
+
 template <>
 __forceinline__ __device__ uint2 add_vectors<__half>(uint2 a, uint2 b) {
   return add_vectors_helper<__half2>(a, b);
 }
 
+template <>
+__forceinline__ __device__ uint2 min_vectors<__half>(uint2 a, uint2 b) {
+  return min_vectors_helper<__half2>(a, b);
+}
+
+template <>
+__forceinline__ __device__ uint2 cal_vectors<__half>(uint2 a, uint2 b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<__half2, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<__half2, MIN>(a, b);
+  }
+}
+
 template <>
 __forceinline__ __device__ uint2 add_vectors<__bfloat16>(uint2 a, uint2 b) {
   return add_vectors_helper<__bfloat162>(a, b);
 }
 
+template <>
+__forceinline__ __device__ uint2 min_vectors<__bfloat16>(uint2 a, uint2 b) {
+  return min_vectors_helper<__bfloat162>(a, b);
+}
+
 template <typename T>
 __forceinline__ __device__ int add_vectors_helper(int a, int b) {
   return bit_cast<int, T>(add_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
 }
 
+template <typename T>
+__forceinline__ __device__ int min_vectors_helper(int a, int b) {
+  return bit_cast<int, T>(min_elements(bit_cast<T, int>(a), bit_cast<T, int>(b)));
+}
+
+template <typename T, Op op>
+__forceinline__ __device__ int cal_vectors_helper(int a, int b) {
+  if constexpr (op == SUM) {
+    return add_vectors_helper<T>(a, b);
+  } else if constexpr (op == MIN) {
+    return min_vectors_helper<T>(a, b);
+  }
+  return a;
+}
+
 template <typename T>
 __forceinline__ __device__ int add_vectors(int a, int b) {
   return add_vectors_helper<T>(a, b);
 }
 
+template <typename T>
+__forceinline__ __device__ int min_vectors(int a, int b) {
+  return min_vectors_helper<T>(a, b);
+}
+
+template <typename T>
+__forceinline__ __device__ int cal_vectors(int a, int b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<T, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<T, MIN>(a, b);
+  }
+}
+
 template <>
 __forceinline__ __device__ int add_vectors<__half>(int a, int b) {
   return add_vectors_helper<__half2>(a, b);
 }
 
+template <>
+__forceinline__ __device__ int min_vectors<__half>(int a, int b) {
+  return min_vectors_helper<__half2>(a, b);
+}
+
+template <>
+__forceinline__ __device__ int cal_vectors<__half>(int a, int b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<__half2, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<__half2, MIN>(a, b);
+  }
+}
+
 template <>
 __forceinline__ __device__ int add_vectors<__bfloat16>(int a, int b) {
   return add_vectors_helper<__bfloat162>(a, b);
 }
 
+template <>
+__forceinline__ __device__ int min_vectors<__bfloat16>(int a, int b) {
+  return min_vectors_helper<__bfloat162>(a, b);
+}
+
+template <>
+__forceinline__ __device__ int cal_vectors<__bfloat16>(int a, int b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<__bfloat162, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<__bfloat162, MIN>(a, b);
+  }
+}
+
 template <typename T>
 __forceinline__ __device__ uint32_t add_vectors_helper(uint32_t a, uint32_t b) {
   return bit_cast<uint32_t, T>(add_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
 }
 
+template <typename T>
+__forceinline__ __device__ uint32_t min_vectors_helper(uint32_t a, uint32_t b) {
+  return bit_cast<uint32_t, T>(min_elements(bit_cast<T, uint32_t>(a), bit_cast<T, uint32_t>(b)));
+}
+
+template <typename T, Op op>
+__forceinline__ __device__ uint32_t cal_vectors_helper(uint32_t a, uint32_t b) {
+  if constexpr (op == SUM) {
+    return add_vectors_helper<T>(a, b);
+  } else if constexpr (op == MIN) {
+    return min_vectors_helper<T>(a, b);
+  }
+  return a;
+}
+
 template <typename T>
 __forceinline__ __device__ uint32_t add_vectors(uint32_t a, uint32_t b) {
   return add_vectors_helper<T>(a, b);
 }
 
+template <typename T>
+__forceinline__ __device__ uint32_t min_vectors(uint32_t a, uint32_t b) {
+  return min_vectors_helper<T>(a, b);
+}
+
+template <typename T>
+__forceinline__ __device__ uint32_t cal_vectors(uint32_t a, uint32_t b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<T, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<T, MIN>(a, b);
+  }
+}
+
 template <>
 __forceinline__ __device__ uint32_t add_vectors<__half>(uint32_t a, uint32_t b) {
   return add_vectors_helper<__half2>(a, b);
 }
 
+template <>
+__forceinline__ __device__ uint32_t min_vectors<__half>(uint32_t a, uint32_t b) {
+  return min_vectors_helper<__half2>(a, b);
+}
+
+template <>
+__forceinline__ __device__ uint32_t cal_vectors<__half>(uint32_t a, uint32_t b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<__half2, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<__half2, MIN>(a, b);
+  }
+}
+
 template <>
 __forceinline__ __device__ uint32_t add_vectors<__bfloat16>(uint32_t a, uint32_t b) {
   return add_vectors_helper<__bfloat162>(a, b);
 }
 
+template <>
+__forceinline__ __device__ uint32_t min_vectors<__bfloat16>(uint32_t a, uint32_t b) {
+  return min_vectors_helper<__bfloat162>(a, b);
+}
+
+template <>
+__forceinline__ __device__ uint32_t cal_vectors<__bfloat16>(uint32_t a, uint32_t b, Op op) {
+  if (op == SUM) {
+    return cal_vectors_helper<__bfloat162, SUM>(a, b);
+  } else {
+    return cal_vectors_helper<__bfloat162, MIN>(a, b);
+  }
+}
+
 template <typename T>
 __forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem, int blockId, int nBlocks) {
   size_t nInt4 = nElem / 4;
@@ -198,7 +441,7 @@ template <typename T>
 __global__ void __launch_bounds__(32, 1)
     allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                       size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      size_t nelems, uint32_t flag) {
+                      Op op, size_t nelems, uint32_t flag) {
   // This version of allreduce only works for single nodes
   if (worldSize != nRanksPerNode) return;
   if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
@@ -226,14 +469,13 @@
 
   // step 2: Reduce Data
   for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
-    uint32_t data = 0;
+    uint32_t data = src[idx];
     for (int index = 0; index < nPeers; index++) {
       const int remoteRank = index < rank ? index : index + 1;
       mscclpp::LL8Packet* dstPkt = (mscclpp::LL8Packet*)scratchBuff + remoteRank * nelems;
       uint32_t val = dstPkt[idx].read(flag, -1);
-      data = add_vectors<T>(val, data);
+      data = cal_vectors<T>(val, data, op);
     }
-    data = add_vectors<T>(data, src[idx]);
     dst[idx] = data;
   }
 }
@@ -241,7 +483,7 @@
 template <typename T>
 __global__ void __launch_bounds__(1024, 1)
     allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
-               size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
+               size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, Op op,
                size_t nelems, uint32_t flag
 #if defined(ENABLE_NPKIT)
                ,
@@ -321,8 +563,8 @@ __global__ void __launch_bounds__(1024, 1)
       const int remoteRank = index < rank ? index : index + 1;
       mscclpp::LLPacket* dstPkt = (mscclpp::LLPacket*)scratchBuff + remoteRank * nPktsPerRank;
       uint2 val = dstPkt[idx].read(flag);
-      data.x = add_vectors<T>(val.x, data.x);
-      data.y = add_vectors<T>(val.y, data.y);
+      data.x = cal_vectors<T>(val.x, data.x, op);
+      data.y = cal_vectors<T>(val.y, data.y, op);
     }
 
     dst[idx].x = data.x;
@@ -499,7 +741,7 @@ template <typename T>
 cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
                       mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
                       size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      size_t nelems, cudaStream_t stream) {
+                      Op op, size_t nelems, cudaStream_t stream) {
   static uint32_t flag = 1;
 
   if (sizeof(T) * nelems < worldSize * sizeof(int)) {
@@ -507,7 +749,7 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
     int nThreadsPerBlock = 32;
     allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels,
                                                                 channelInOffset, channelScratchOffset, rank,
-                                                                nRanksPerNode, worldSize, nelems, flag++);
+                                                                nRanksPerNode, worldSize, op, nelems, flag++);
   } else if (sizeof(T) * nelems <= (1 << 20)) {
     int nBlocks = 28;
     int nThreadsPerBlock = 1024;
@@ -519,11 +761,11 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
     size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
     allreduce7<<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
         buff, scratch, resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode,
-        worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
+        worldSize, op, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
 #else
     allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, channelInOffset,
-                                                         channelScratchOffset, rank, nRanksPerNode, worldSize, nelems,
-                                                         flag++);
+                                                         channelScratchOffset, rank, nRanksPerNode, worldSize, op,
+                                                         nelems, flag++);
 #endif
   } else {
     int nBlocks = 35;
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 4528200c..4f35fdab 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -138,6 +138,18 @@ static mscclpp::Transport getTransport(int, int) {
   return mscclpp::Transport::CudaIpc;
 }
 
+static Op getReduceOp(ncclRedOp_t op) {
+  switch (op) {
+    case ncclSum:
+      return SUM;
+    case ncclMin:
+      return MIN;
+    default:
+      WARN("op is invalid, op: %d", op);
+      throw mscclpp::Error("Invalid operation", mscclpp::ErrorCode::InternalError);
+  }
+}
+
 static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
                                                                   void* buff, size_t bytes,
                                                                   mscclpp::TransportFlags transport) {
@@ -268,27 +280,28 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
     memoryOutChannels = recvIt->second.memoryChannelDeviceHandles.get();
   }
 
+  Op reduceOp = getReduceOp(op);
   switch (datatype) {
     case ncclFloat16:
       CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, memoryChannels,
                           memoryOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream));
+                          comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
       break;
     case ncclFloat32:
       CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, memoryChannels,
                           memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
-                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
       break;
     case ncclBfloat16:
      CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
                          memoryChannels, memoryOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
-                         comm->comm->bootstrap()->getNranks(), count, stream));
+                         comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
       break;
     case ncclInt32:
     case ncclUint32:
       CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, memoryChannels,
                           memoryOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
-                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), reduceOp, count, stream));
       break;
     default:
       WARN("datatype is invalid, datatype: %d", datatype);
@@ -955,10 +968,10 @@ ncclResult_t ncclMemAlloc(void** ptr, size_t size) {
       return ncclInternalError;
     }
   } catch (const mscclpp::CudaError& e) {
-    INFO(MSCCLPP_ALLOC, "Cuda error: %s", e.what());
+    WARN("Cuda error: %s", e.what());
    return ncclUnhandledCudaError;
   } catch (const mscclpp::CuError& e) {
-    INFO(MSCCLPP_ALLOC, "Cu error: %s", e.what());
+    WARN("Cu error: %s", e.what());
     return ncclUnhandledCudaError;
   } catch (const mscclpp::BaseError& e) {
     WARN("Base error: %s", e.what());
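
Note (not part of the patch): the change routes the reduction operator from the NCCL entry point (getReduceOp in nccl.cu) down to the device-side cal_vectors dispatch, so ncclMin now reaches the fallback kernels instead of always summing. A minimal, hypothetical usage sketch in CUDA C++ that exercises the new MIN path through the standard NCCL API; it assumes this shim is linked in place of libnccl and that a communicator and stream already exist:

// sketch.cu -- hypothetical example, not part of this change
#include <cuda_fp16.h>
#include <nccl.h>

void allreduce_min_fp16(const __half* d_send, __half* d_recv, size_t count,
                        ncclComm_t comm, cudaStream_t stream) {
  // ncclMin maps to Op::MIN (value 3) via getReduceOp(); the fallback kernels
  // then reduce with cal_vectors<T>(..., MIN) instead of add_vectors<T>.
  ncclAllReduce(d_send, d_recv, count, ncclFloat16, ncclMin, comm, stream);
}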