From affca7d9bc27c299b759cedf69a3ccc86d9c2e7b Mon Sep 17 00:00:00 2001
From: Binyang Li
Date: Sun, 27 Apr 2025 14:09:31 -0700
Subject: [PATCH] Add NVLS based fallback algo (#507)

Add two NVLS-based fallback algorithms. allreduce9 is the NVLS algorithm
with zero copy. allreduce10 is the NVLS algorithm for the case where we
need to copy the input to the scratch buffer, do the reduce operation,
then copy the result back to the result buffer.

Perf numbers for allreduce9:

```
#                                                              out-of-place                       in-place
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)
        1024           256     float     sum      -1     5.45    0.19    0.33       0     5.35    0.19    0.33       0
        2048           512     float     sum      -1     5.57    0.37    0.64       0     5.53    0.37    0.65       0
        4096          1024     float     sum      -1     5.80    0.71    1.24       0     5.78    0.71    1.24       0
        8192          2048     float     sum      -1     5.94    1.38    2.42       0     5.85    1.40    2.45       0
       16384          4096     float     sum      -1     6.40    2.56    4.48       0     6.27    2.61    4.57       0
       32768          8192     float     sum      -1     7.45    4.40    7.70       0     7.39    4.43    7.76       0
       65536         16384     float     sum      -1     8.03    8.17   14.29       0     8.32    7.88   13.79       0
      131072         32768     float     sum      -1     7.28   18.00   31.49       0     7.07   18.53   32.43       0
      262144         65536     float     sum      -1     7.72   33.95   59.41       0     7.59   34.56   60.48       0
      524288        131072     float     sum      -1     8.70   60.29  105.51       0     8.37   62.61  109.57       0
     1048576        262144     float     sum      -1    10.56   99.26  173.70       0    10.32  101.64  177.87       0
     2097152        524288     float     sum      -1    14.45  145.14  253.99       0    14.02  149.58  261.76       0
     4194304       1048576     float     sum      -1    22.83  183.73  321.52       0    23.03  182.14  318.75       0
     8388608       2097152     float     sum      -1    38.63  217.14  380.00       0    38.57  217.52  380.65       0
    16777216       4194304     float     sum      -1    70.03  239.58  419.27       0    69.96  239.80  419.66       0
    33554432       8388608     float     sum      -1    131.5  255.17  446.55       0    131.3  255.59  447.28       0
    67108864      16777216     float     sum      -1    255.8  262.37  459.15       0    255.4  262.75  459.82       0
   134217728      33554432     float     sum      -1    500.9  267.94  468.90       0    500.0  268.42  469.74       0
   268435456      67108864     float     sum      -1    989.0  271.41  474.97       0    988.9  271.45  475.05       0
   536870912     134217728     float     sum      -1   1967.4  272.88  477.54       0   1966.0  273.08  477.88       0
  1073741824     268435456     float     sum      -1   3908.5  274.72  480.77       0   3904.6  274.99  481.24       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 218.734
```

Perf numbers for allreduce10:

```
#                                                              out-of-place                       in-place
#       size         count      type   redop    root     time   algbw   busbw  #wrong     time   algbw   busbw  #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)             (us)  (GB/s)  (GB/s)
        1024           256     float     sum      -1     5.60    0.18    0.32       0     5.52    0.19    0.32       0
        2048           512     float     sum      -1     5.79    0.35    0.62       0     5.64    0.36    0.64       0
        4096          1024     float     sum      -1     5.92    0.69    1.21       0     5.82    0.70    1.23       0
        8192          2048     float     sum      -1     6.03    1.36    2.38       0     5.95    1.38    2.41       0
       16384          4096     float     sum      -1     6.58    2.49    4.35       0     6.39    2.56    4.49       0
       32768          8192     float     sum      -1     7.54    4.34    7.60       0     7.41    4.42    7.74       0
       65536         16384     float     sum      -1     7.95    8.24   14.42       0     8.10    8.09   14.16       0
      131072         32768     float     sum      -1     9.56   13.72   24.00       0     9.47   13.84   24.23       0
      262144         65536     float     sum      -1    11.49   22.81   39.92       0    11.41   22.97   40.20       0
      524288        131072     float     sum      -1    14.19   36.94   64.64       0    13.88   37.76   66.09       0
     1048576        262144     float     sum      -1    19.10   54.89   96.06       0    18.98   55.24   96.67       0
     2097152        524288     float     sum      -1    31.12   67.38  117.91       0    31.34   66.92  117.10       0
     4194304       1048576     float     sum      -1    44.88   93.46  163.56       0    44.76   93.70  163.97       0
     8388608       2097152     float     sum      -1    63.23  132.68  232.18       0    62.53  134.14  234.75       0
    16777216       4194304     float     sum      -1    106.8  157.03  274.80       0    105.9  158.46  277.30       0
    33554432       8388608     float     sum      -1    172.2  194.91  341.09       0    172.0  195.05  341.35       0
    67108864      16777216     float     sum      -1    299.8  223.83  391.70       0    300.8  223.12  390.46       0
   134217728      33554432     float     sum      -1    553.1  242.66  424.66       0    553.8  242.38  424.16       0
   268435456      67108864     float     sum      -1   1056.1  254.18  444.82       0   1057.4  253.86  444.26       0
   536870912     134217728     float     sum      -1   2064.0  260.11  455.20       0   2063.8  260.14  455.25       0
  1073741824     268435456     float     sum      -1   4074.4  263.53  461.18       0   4065.8  264.09  462.16       0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 169.799
```

---------

Co-authored-by: Sreevatsa Anantharamu
Co-authored-by: Changho Hwang
---
 apps/nccl/src/allreduce.hpp | 224 ++++++++++++++++++++++++++++++++++--
 apps/nccl/src/common.hpp    |   4 +
 apps/nccl/src/nccl.cu       | 118 +++++++++++++++++--
 include/mscclpp/nvls.hpp    |   2 +-
 src/nvls.cc                 |   7 +-
 5 files changed, 336 insertions(+), 19 deletions(-)

diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index eebc648a..cc6f23a0 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -10,6 +10,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -164,6 +165,34 @@ __forceinline__ __device__ DataType cal_vectors(DataType a, DataType b) {
   return cal_vectors_helper(a, b);
 }
 
+template <typename T>
+struct VectorType {
+  using type = T;
+  using nvls_type = T;
+  using nvls_type2 = T;
+};
+
+template <>
+struct VectorType<__half> {
+  using type = __half2;
+  using nvls_type = uint4;
+  using nvls_type2 = uint1;
+};
+
+template <>
+struct VectorType<__bfloat16> {
+  using type = __bfloat162;
+  using nvls_type = uint4;
+  using nvls_type2 = uint1;
+};
+
+template <>
+struct VectorType<float> {
+  using type = float;
+  using nvls_type = uint4;
+  using nvls_type2 = uint1;
+};
+
 template <typename T>
 __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
                                   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
@@ -477,14 +506,185 @@ __global__ void __launch_bounds__(512, 1)
   }
 }
 
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+template <typename T>
+MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, uint32_t srcOffset, uint32_t dstOffset,
+                                                      size_t size, int tid, int nThreads) {
+  using vectorType = typename VectorType<T>::type;
+  using nvlsType = typename VectorType<T>::nvls_type;
+  // nvls can only handle 4 bytes alignment
+  assert(size % sizeof(vectorType) == 0);
+  const size_t nInt4 = size / sizeof(nvlsType);
+  const size_t srcOffset4 = srcOffset / sizeof(nvlsType);
+  const size_t dstOffset4 = dstOffset / sizeof(nvlsType);
+  nvlsType* src4 = (nvlsType*)src;
+  nvlsType* dst4 = (nvlsType*)dst;
+  for (size_t idx = tid; idx < nInt4; idx += nThreads) {
+    nvlsType val;
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (vectorType*)(src4 + srcOffset4 + idx));
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (vectorType*)(dst4 + dstOffset4 + idx));
+  }
+  // handle rest of data
+  size_t processed = nInt4 * sizeof(nvlsType);
+  using nvlsType2 = typename VectorType<T>::nvls_type2;
+  const size_t startIdx = (srcOffset + processed) / sizeof(nvlsType2);
+  const size_t endIdx = (dstOffset + size) / sizeof(nvlsType2);
+  for (size_t idx = tid + startIdx; idx < endIdx; idx += nThreads) {
+    nvlsType2 val;
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (vectorType*)src + idx);
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (vectorType*)dst + idx);
+  }
+}
+#endif
+
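Note: handleMultiLoadReduceStore walks the multicast buffer in wide `nvls_type` (16-byte) chunks and mops up the remainder with the narrower `nvls_type2`. For readers unfamiliar with that bulk-plus-tail pattern, here is a standalone sketch of the same indexing on plain device memory; the function name and the `scale` operation are illustrative stand-ins (scaling replaces the multimem load-reduce/store pair), not part of the patch:

```cuda
// Illustrative bulk-plus-tail loop, mirroring handleMultiLoadReduceStore
// but on ordinary device pointers.
__device__ void scaleBulkThenTail(float* dst, const float* src, size_t bytes,
                                  int tid, int nThreads, float scale) {
  const size_t nVec = bytes / sizeof(float4);  // number of full 16-byte chunks
  const float4* src4 = reinterpret_cast<const float4*>(src);
  float4* dst4 = reinterpret_cast<float4*>(dst);
  for (size_t i = tid; i < nVec; i += nThreads) {  // vectorized body
    float4 v = src4[i];
    v.x *= scale; v.y *= scale; v.z *= scale; v.w *= scale;
    dst4[i] = v;
  }
  // Tail: trailing elements that do not fill a whole float4.
  const size_t firstTail = nVec * (sizeof(float4) / sizeof(float));
  for (size_t i = firstTail + tid; i < bytes / sizeof(float); i += nThreads) {
    dst[i] = src[i] * scale;
  }
}
```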
+template <typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduce9([[maybe_unused]] mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
+               [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicast,
+               [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastOut,
+               [[maybe_unused]] size_t channelInOffset, [[maybe_unused]] size_t channelOutOffset,
+               [[maybe_unused]] size_t size, [[maybe_unused]] int rank) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  int nBlocks = gridDim.x;
+  int bid = blockIdx.x;
+  size_t sizePerRank = size / 8;
+  size_t sizePerBlock = sizePerRank / nBlocks;
+  size_t rankOffset = sizePerRank * rank;
+  size_t blockOffset = sizePerBlock * bid + rankOffset;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastPtr = multicast + bid;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastOutPtr = multicastOut + bid;
+
+  const size_t chanOffset = (NRANKS_PER_NODE - 1) * blockIdx.x;
+  auto memoryChans = memoryChannels + chanOffset;
+  __shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[NRANKS_PER_NODE - 1];
+  const int lid = threadIdx.x % WARP_SIZE;
+  if (lid < NRANKS_PER_NODE - 1) {
+    channels[lid] = memoryChans[lid];
+  }
+  __syncwarp();
+  if (threadIdx.x < NPEERS) {
+    channels[threadIdx.x].relaxedSignal();
+    channels[threadIdx.x].relaxedWait();
+  }
+  __syncthreads();
+  T* src = (T*)multicastPtr->mcPtr;
+  T* dst = (T*)multicastOutPtr->mcPtr;
+  handleMultiLoadReduceStore(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset, sizePerBlock,
+                             threadIdx.x, blockDim.x);
+  __syncthreads();
+  if (threadIdx.x < NPEERS) {
+    channels[threadIdx.x].relaxedSignal();
+    channels[threadIdx.x].relaxedWait();
+  }
+#endif
+}
+
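Note: allreduce9 assigns each block one NVLS connection and each rank a 1/8 shard of the buffer (the hard-coded `size / 8` mirrors NRANKS_PER_NODE), so every element is load-reduced exactly once across the node. A worked example of the offset arithmetic, with hypothetical sizes:

```cpp
// Hypothetical numbers: 64 MB buffer, 8 ranks, 8 blocks per rank.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t size = 64ull << 20;
  const size_t sizePerRank = size / 8;          // 8 MB shard owned by each rank
  const size_t sizePerBlock = sizePerRank / 8;  // 1 MB slice per block
  // Same formula as the kernel: blockOffset = sizePerBlock * bid + sizePerRank * rank.
  for (int rank = 0; rank < 2; ++rank)
    for (int bid = 0; bid < 2; ++bid)
      std::printf("rank %d, block %d -> bytes [%zu, %zu)\n", rank, bid,
                  sizePerRank * rank + sizePerBlock * bid,
                  sizePerRank * rank + sizePerBlock * (bid + 1));
  return 0;
}
```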
+template <typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduce10([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch, [[maybe_unused]] void* dst,
+                [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
+                [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicast,
+                [[maybe_unused]] size_t size, [[maybe_unused]] int rank) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  int nBlocks = gridDim.x;
+  int nBlocksPerNvlsConn = nBlocks / NUM_NVLS_CONNECTION;
+  int bid = blockIdx.x;
+  size_t sizePerRank = size / NRANKS_PER_NODE;
+  constexpr size_t scratchSizePerRank = SCRATCH_SIZE / NRANKS_PER_NODE;
+  const size_t maxSizePerBlock = (sizePerRank + nBlocks - 1) / nBlocks;
+  size_t start = bid * maxSizePerBlock;
+  size_t end = min(start + maxSizePerBlock, sizePerRank);
+  size_t sizePerBlock = end - start;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastPtr =
+      multicast + bid / nBlocksPerNvlsConn;
+  size_t copyPerIter = 1024 * 16;
+  if (sizePerBlock >= 1024 * 64) {
+    copyPerIter = 1024 * 32;
+  }
+  size_t scratchSizePerBlock = (scratchSizePerRank / nBlocks) / copyPerIter * copyPerIter;
+  size_t blockScratchOffset = scratchSizePerBlock * bid + scratchSizePerRank * rank;
+  constexpr int NCOPY_WARPS = 14;
+  constexpr int NREDUCE_WARPS = 4;
+  constexpr int NRECV_COPY_WARPS = 14;
+  constexpr int endCopyWid = NCOPY_WARPS;
+  constexpr int startRecvCopyWid = NCOPY_WARPS;
+  constexpr int endRecvCopyWid = NCOPY_WARPS + NRECV_COPY_WARPS;
+  constexpr int endReduceWid = NCOPY_WARPS + NREDUCE_WARPS + NRECV_COPY_WARPS;
+  const int warpId = threadIdx.x / WARP_SIZE;
+  size_t nIter = sizePerBlock / copyPerIter;
+  size_t lastIterSize = copyPerIter;
+  if (sizePerBlock % copyPerIter != 0) {
+    nIter += 1;
+    lastIterSize = sizePerBlock % copyPerIter;
+  }
+
+  const size_t chanOffset = (NRANKS_PER_NODE - 1) * blockIdx.x * 2;
+  auto memoryChans = memoryChannels + chanOffset;
+  __shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[(NRANKS_PER_NODE - 1) * 2];
+  const int lid = threadIdx.x % WARP_SIZE;
+  if (lid < (NRANKS_PER_NODE - 1) * 2) {
+    channels[lid] = memoryChans[lid];
+  }
+  __syncwarp();
+  for (int it = 0; it < nIter; it++) {
+    const size_t iterSize = (it == nIter - 1) ? lastIterSize : copyPerIter;
+    if (warpId < endCopyWid) {
+      int tidInCopy = threadIdx.x;
+      for (int i = 0; i < NRANKS_PER_NODE; i++) {
+        size_t offset = i * sizePerRank + sizePerBlock * bid + it * copyPerIter;
+        size_t offsetScratch =
+            i * scratchSizePerRank + scratchSizePerBlock * bid + (it * copyPerIter) % scratchSizePerBlock;
+        char* srcData = (char*)src + offset;
+        char* dstData = (char*)scratch + offsetScratch;
+        mscclpp::copy(dstData, srcData, iterSize, tidInCopy, NCOPY_WARPS * WARP_SIZE);
+      }
+      asm volatile("bar.sync %0, %1;" ::"r"(0), "r"(NCOPY_WARPS * WARP_SIZE) : "memory");
+      if (tidInCopy < NPEERS) {
+        channels[tidInCopy].signal();
+        channels[tidInCopy].wait();
+      }
+      asm volatile("bar.sync %0, %1;" ::"r"(1), "r"((NCOPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+    }
+    if (warpId >= endRecvCopyWid && warpId < endReduceWid) {
+      int tidInReduce = threadIdx.x - endRecvCopyWid * WARP_SIZE;
+      asm volatile("bar.sync %0, %1;" ::"r"(1), "r"((NCOPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+      T* mcBuff = (T*)multicastPtr->mcPtr;
+      size_t offset = blockScratchOffset + (it * copyPerIter) % scratchSizePerBlock;
+      handleMultiLoadReduceStore(mcBuff, mcBuff, offset, offset, iterSize, tidInReduce, NREDUCE_WARPS * WARP_SIZE);
+      asm volatile("bar.sync %0, %1;" ::"r"(2), "r"((NRECV_COPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+    }
+    if (warpId >= startRecvCopyWid && warpId < endRecvCopyWid) {
+      int tidInRecvCopy = threadIdx.x - startRecvCopyWid * WARP_SIZE;
+      asm volatile("bar.sync %0, %1;" ::"r"(2), "r"((NRECV_COPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+      if (tidInRecvCopy < NPEERS) {
+        channels[tidInRecvCopy + NPEERS].signal();
+        channels[tidInRecvCopy + NPEERS].wait();
+      }
+      asm volatile("bar.sync %0, %1;" ::"r"(3), "r"((NRECV_COPY_WARPS)*WARP_SIZE) : "memory");
+      for (int i = 0; i < NRANKS_PER_NODE; i++) {
+        size_t offset = i * sizePerRank + sizePerBlock * bid + it * copyPerIter;
+        size_t offsetScratch =
+            i * scratchSizePerRank + scratchSizePerBlock * bid + (it * copyPerIter) % scratchSizePerBlock;
+        char* srcData = (char*)scratch + offsetScratch;
+        char* dstData = (char*)dst + offset;
+        mscclpp::copy(dstData, srcData, iterSize, tidInRecvCopy, NRECV_COPY_WARPS * WARP_SIZE);
+      }
+    }
+  }
+#endif
+}
+
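Note: allreduce10 specializes its 32 warps per block into three roles (14 copy warps stage input into scratch, 4 reduce warps run the multimem reduce over scratch, 14 receive-copy warps drain results) and hands work off between the roles with PTX named barriers (`bar.sync id, count`). A minimal sketch of that named-barrier handoff, with illustrative warp counts:

```cuda
// Two producer warps fill buf, then meet two consumer warps at named
// barrier 1 (barrier 0 is conventionally left to __syncthreads()).
// The thread count passed to bar.sync must be a multiple of the warp size,
// and every thread in a participating warp must reach the barrier.
__global__ void producerConsumer(int* buf, int n) {
  constexpr int PROD_WARPS = 2, CONS_WARPS = 2;
  const int warpId = threadIdx.x / 32;
  if (warpId < PROD_WARPS) {
    for (int i = threadIdx.x; i < n; i += PROD_WARPS * 32) buf[i] = i;  // produce
    asm volatile("bar.sync 1, %0;" ::"r"((PROD_WARPS + CONS_WARPS) * 32) : "memory");
  } else if (warpId < PROD_WARPS + CONS_WARPS) {
    asm volatile("bar.sync 1, %0;" ::"r"((PROD_WARPS + CONS_WARPS) * 32) : "memory");
    const int tid = threadIdx.x - PROD_WARPS * 32;
    for (int i = tid; i < n; i += CONS_WARPS * 32) buf[i] *= 2;  // consume
  }
}
```

Because only the named warps arrive at each barrier, the copy of iteration i+1 can overlap the reduce of iteration i, which is what makes the scratch-buffer pipeline in allreduce10 profitable.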
 template <typename T>
 cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
                       mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
-                      mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
-                      size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28,
-                      uint32_t* deviceFlag56, uint32_t numScratchBuff) {
-  uint32_t* deviceFlag;
+                      mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels,
+                      mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsChannels,
+                      mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsOutChannels,
+                      size_t channelInOffset, size_t channelOutOffset, size_t channelScratchOffset, int rank,
+                      int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7,
+                      uint32_t* deviceFlag28, uint32_t* deviceFlag56, uint32_t numScratchBuff) {
+  bool useNvlsWithZeroCopy = mscclpp::isNvlsSupported() && !mscclppDisableChannelCache;
+
   if (sizeof(T) * nelems < worldSize * sizeof(int)) {
     int nBlocks = 7;
     int nThreadsPerBlock = 32;
@@ -497,10 +697,10 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
     allreduceAllPairs<<<nBlocks, nThreadsPerBlock, 0, stream>>>(
         (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
         nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff);
-  } else if (sizeof(T) * nelems <= (1 << 20)) {
+  } else if (sizeof(T) * nelems <= (1 << 16) || (sizeof(T) * nelems <= (1 << 20) && !useNvlsWithZeroCopy)) {
     int nBlocks = 28;
     int nThreadsPerBlock = 1024;
-    deviceFlag = deviceFlag28;
+    uint32_t* deviceFlag = deviceFlag28;
     if (nelems >= 8192) {
       nBlocks = 56;
       nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
@@ -517,6 +717,16 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
         (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
         nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff);
 #endif
+  } else if (useNvlsWithZeroCopy) {
+    int nBlocks = 8;
+    int nThreadsPerBlock = 1024;
+    allreduce9<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
+        memoryChannels, nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, nelems * sizeof(T), rank);
+  } else if (mscclpp::isNvlsSupported()) {
+    int nBlocks = 32;
+    int nThreadsPerBlock = 1024;
+    allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, nvlsChannels,
+                                                             nelems * sizeof(T), rank);
   } else {
     int nBlocks = 35;
     int nThreadsPerBlock = 512;
diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp
index 015e0a2f..f6db6ea7 100644
--- a/apps/nccl/src/common.hpp
+++ b/apps/nccl/src/common.hpp
@@ -5,6 +5,7 @@
 #define NCCL_COMMON_HPP_
 
 #include
+#include
 
 #if defined(__HIP_PLATFORM_AMD__)
 #define WARP_SIZE 64
@@ -13,10 +14,13 @@
 #define WARP_SIZE 32
 #endif
 
+constexpr int NUM_NVLS_CONNECTION = 8;
+
 constexpr int NRANKS_PER_NODE = 8;
 constexpr int NPEERS = 7;
 constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70;  // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB
+static bool mscclppDisableChannelCache = mscclpp::env()->disableChannelCache;
 
 __device__ mscclpp::DeviceSyncer deviceSyncer;
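Note: with these changes the size-based dispatch in allreduce() has four live branches plus the large-message default. Restated as a pure function for clarity (the enum and function name are illustrative; the patch itself branches inline):

```cpp
// Hedged restatement of the new dispatch thresholds.
#include <cstddef>

enum class Algo { AllPairs, Packet, NvlsZeroCopy, NvlsCopy, LargeDefault };

Algo pickAllreduce(size_t bytes, int worldSize, bool nvls, bool channelCacheOn) {
  const bool zeroCopy = nvls && channelCacheOn;
  if (bytes < worldSize * sizeof(int)) return Algo::AllPairs;   // tiny messages
  if (bytes <= (1 << 16) || (bytes <= (1 << 20) && !zeroCopy))
    return Algo::Packet;                                        // packet-based kernels
  if (zeroCopy) return Algo::NvlsZeroCopy;                      // allreduce9
  if (nvls) return Algo::NvlsCopy;                              // allreduce10
  return Algo::LargeDefault;                                    // pre-existing large path
}
```

In other words, NVLS with zero copy now takes over from the packet kernels above 64 KB, while without channel caching the packet kernels keep the range up to 1 MB and allreduce10 handles everything larger.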
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 739d3ec8..a8962bbb 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -10,7 +10,9 @@
 #include
 #include
 #include
+#include
 #include
+#include
 #include
 #include
 #include
@@ -37,6 +39,7 @@
 } while (0)
 
 #define NUM_CHANNELS_PER_CONNECTION 64
+static constexpr size_t NVLS_BUFFER_SIZE = (1 << 30);
 
 typedef enum mscclppNcclDlopenErr {
   dlopenSuccess = 0,
@@ -139,7 +142,6 @@ static inline int mscclppNcclInFallbackList(const char* collOps, const char* fal
 
 // Declare the global map to store associations between raw pointer and shared pointer
 static std::unordered_map<void*, std::shared_ptr<char>> ptrMap;
-static bool mscclppDisableChannelCache = mscclpp::env()->disableChannelCache;
 
 struct channelKey {
   const void* buff;
@@ -172,6 +174,11 @@ struct ChannelInfo {
   std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
 };
 
+struct NvlsChannelInfo {
+  std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> nvlsChannels;
+  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>> nvlsChannelDeviceHandles;
+};
+
 struct splitCommInfo {
   int color;
   int key;
 };
 
 struct ncclComm {
   std::shared_ptr<mscclpp::Communicator> comm;
   std::vector<std::shared_ptr<mscclpp::Connection>> connections;
+  std::vector<std::shared_ptr<mscclpp::NvlsConnection>> nvlsConnections;
+  std::vector<std::shared_ptr<mscclpp::NvlsConnection>> nvlsConnectionsOut;
   std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
   std::shared_ptr<mscclpp::Executor> executor;
   std::unordered_map<std::string, std::shared_ptr<mscclpp::ExecutionPlan>> executionPlans;
@@ -188,6 +197,7 @@ struct ncclComm {
   std::unordered_map<channelKey, ChannelInfo> channelInInfos;
   std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
   std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
+  std::unordered_map<channelKey, NvlsChannelInfo> channelNvlsInfos;
   std::shared_ptr<char> scratchBuff;
   std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
   std::vector<ChannelInfo> channelInfos;
@@ -287,6 +297,34 @@ static std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
   return channels;
 }
 
+static std::vector<std::shared_ptr<mscclpp::NvlsConnection>> setupNvlsConnections(ncclComm_t comm, size_t size) {
+  // for nvls connection
+  std::vector<std::shared_ptr<mscclpp::NvlsConnection>> nvlsConnections;
+  int nRanks = comm->comm->bootstrap()->getNranks();
+  std::vector<int> ranks;
+  for (int i = 0; i < nRanks; i++) {
+    ranks.push_back(i);
+  }
+  for (int i = 0; i < NUM_NVLS_CONNECTION; i++) {
+    std::shared_ptr<mscclpp::NvlsConnection> nvlsConnection = mscclpp::connectNvlsCollective(comm->comm, ranks, size);
+    nvlsConnections.push_back(nvlsConnection);
+  }
+  return nvlsConnections;
+}
+
+static std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> setupNvlsChannels(
+    std::vector<std::shared_ptr<mscclpp::NvlsConnection>> conns, void* buffer, size_t bufferSize) {
+  std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels;
+
+  for (size_t idx = 0; idx < NUM_NVLS_CONNECTION; ++idx) {
+    std::shared_ptr<mscclpp::NvlsConnection> nvlsConnection = conns[idx];
+    mscclpp::NvlsConnection::DeviceMulticastPointer deviceMulticastPointer =
+        nvlsConnection->bindAllocatedMemory((CUdeviceptr)buffer, bufferSize);
+    channels.push_back(deviceMulticastPointer);
+  }
+  return channels;
+}
+
 static std::pair<std::string, std::shared_ptr<mscclpp::ExecutionPlan>> loadExecutionPlan(const std::string& filename) {
   std::shared_ptr<mscclpp::ExecutionPlan> plan = std::make_shared<mscclpp::ExecutionPlan>(filename);
   std::string collective = plan->collective();
@@ -307,6 +345,21 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemor
   return ptr;
 }
 
+static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>>
+setupNvlsChannelDeviceHandles(const std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer>& nvlsChannels) {
+  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>> ptr =
+      mscclpp::detail::gpuCallocShared<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>>(
+          nvlsChannels.size());
+  std::vector<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>> nvlsChannelDeviceHandles;
+  std::transform(nvlsChannels.begin(), nvlsChannels.end(), std::back_inserter(nvlsChannelDeviceHandles),
+                 [](const mscclpp::NvlsConnection::DeviceMulticastPointer& nvlsChannel) {
+                   return mscclpp::deviceHandle(nvlsChannel);
+                 });
+  mscclpp::gpuMemcpy<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>>(
+      ptr.get(), nvlsChannelDeviceHandles.data(), nvlsChannelDeviceHandles.size(), cudaMemcpyHostToDevice);
+  return ptr;
+}
+
 static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                                           ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
   // FallBack for single node
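Note: the three helpers above form a pipeline: create NUM_NVLS_CONNECTION multicast groups once per communicator, bind a concrete device buffer into each group, then publish device-side handles for the kernels. A usage sketch composed only of calls this patch introduces (`comm`, `buf`, and `bytes` are assumed inputs; the wrapper function name is illustrative):

```cpp
// Hedged usage sketch of the new NVLS setup helpers.
void setupNvlsForBuffer(ncclComm_t comm, void* buf, size_t bytes) {
  std::vector<std::shared_ptr<mscclpp::NvlsConnection>> conns =
      setupNvlsConnections(comm, NVLS_BUFFER_SIZE);  // NUM_NVLS_CONNECTION groups
  std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> chans =
      setupNvlsChannels(conns, buf, bytes);          // bind buf into each group
  auto handles = setupNvlsChannelDeviceHandles(chans);
  // handles.get() is the `multicast` argument that allreduce9/allreduce10 expect.
}
```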
@@ -337,10 +390,49 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   channelKey recvKey{(void*)recvBasePtr, recvBytes};
   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels = nullptr;
   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels = nullptr;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsChannels = nullptr;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsOutChannels = nullptr;
   size_t bytes = count * ncclTypeSize(datatype);
+  bool useNvlsWithZeroCopy = mscclpp::isNvlsSupported() && !mscclppDisableChannelCache;
+  bool useNvlsWithCopy = mscclpp::isNvlsSupported() && mscclppDisableChannelCache;
 
   // Creating the channels
-  if (count * ncclTypeSize(datatype) <= (1 << 20)) {
+  if (useNvlsWithZeroCopy) {
+    auto nvlsIt = comm->channelNvlsInfos.find(sendKey);
+    if (nvlsIt == comm->channelNvlsInfos.end()) {
+      std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels =
+          setupNvlsChannels(comm->nvlsConnections, (void*)sendBasePtr, sendBytes);
+      NvlsChannelInfo channelInfo{channels, setupNvlsChannelDeviceHandles(channels)};
+      nvlsIt = comm->channelNvlsInfos.emplace(sendKey, channelInfo).first;
+    }
+    nvlsChannels = nvlsIt->second.nvlsChannelDeviceHandles.get();
+    if (recvbuff != sendbuff) {
+      auto nvlsOutIt = comm->channelNvlsInfos.find(recvKey);
+      if (nvlsOutIt == comm->channelNvlsInfos.end()) {
+        std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels =
+            setupNvlsChannels(comm->nvlsConnectionsOut, (void*)recvBasePtr, recvBytes);
+        NvlsChannelInfo channelInfo{channels, setupNvlsChannelDeviceHandles(channels)};
+        nvlsOutIt = comm->channelNvlsInfos.emplace(recvKey, channelInfo).first;
+      }
+      nvlsOutChannels = nvlsOutIt->second.nvlsChannelDeviceHandles.get();
+    } else {
+      nvlsOutChannels = nvlsChannels;
+    }
+  }
+
+  if (useNvlsWithCopy) {
+    channelKey sendKey{(void*)(comm->scratchBuff.get()), SCRATCH_SIZE};
+    auto nvlsIt = comm->channelNvlsInfos.find(sendKey);
+    if (nvlsIt == comm->channelNvlsInfos.end()) {
+      std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels =
+          setupNvlsChannels(comm->nvlsConnections, (void*)comm->scratchBuff.get(), SCRATCH_SIZE);
+      NvlsChannelInfo channelInfo{channels, setupNvlsChannelDeviceHandles(channels)};
+      nvlsIt = comm->channelNvlsInfos.emplace(sendKey, channelInfo).first;
+    }
+    nvlsChannels = nvlsIt->second.nvlsChannelDeviceHandles.get();
+  }
+
+  if (count * ncclTypeSize(datatype) <= (1 << 20) || mscclpp::isNvlsSupported()) {
     auto sendIt = comm->channelScratchInfos.find(sendKey);
     if (sendIt == comm->channelScratchInfos.end()) {
       std::vector<mscclpp::MemoryChannel> channels =
@@ -386,8 +478,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   Op reduceOp = getReduceOp(op);
   std::function<cudaError_t(const void*, void*, void*, mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
-                            mscclpp::DeviceHandle<mscclpp::MemoryChannel>*, size_t, size_t, size_t, int, int, int,
-                            size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
+                            mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
+                            mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>*,
+                            mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>*, size_t, size_t,
+                            size_t, int, int, int, size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
       allreduceFunc;
   if (reduceOp == SUM) {
     if (datatype == ncclFloat16) {
@@ -416,11 +510,11 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
       return ncclInvalidArgument;
     }
   }
-  CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn,
-                          offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(),
-                          (uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(),
-                          comm->numScratchBuff));
+  CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, nvlsChannels,
+                          nvlsOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream,
+                          (uint32_t*)comm->deviceFlag7.get(), (uint32_t*)comm->deviceFlag28.get(),
+                          (uint32_t*)comm->deviceFlag56.get(), comm->numScratchBuff));
 
   return ncclSuccess;
 }
@@ -447,7 +541,7 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
   MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
   size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
   channelKey recvKey{(void*)recvBasePtr, recvBytes};
-  [[maybe_unused]] channelKey sendKey{(void*)sendBasePtr, sendBytes};
+  [[maybe_unused]] channelKey sendKey{(void*)comm->scratchBuff.get(), SCRATCH_SIZE};
   int rank = comm->comm->bootstrap()->getRank();
   int nRank = comm->comm->bootstrap()->getNranks();
   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels = nullptr;
@@ -533,6 +627,10 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
   mscclppComm->setup();
 
   commPtr->connections = std::move(connections);
+  if (mscclpp::isNvlsSupported()) {
+    commPtr->nvlsConnections = setupNvlsConnections(commPtr, NVLS_BUFFER_SIZE);
+    commPtr->nvlsConnectionsOut = setupNvlsConnections(commPtr, NVLS_BUFFER_SIZE);
+  }
  commPtr->memorySemaphores = std::move(memorySemaphores);
   commPtr->buffFlag = 0;
   commPtr->numScratchBuff = 2;
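Note: the zero-copy path caches NVLS channels per allocation rather than per user pointer: the cache key is the base address and size that cuMemGetAddressRange reports, so two offsets into the same allocation reuse one multicast binding. A sketch of that key derivation (the helper name is illustrative and error handling is omitted; `channelKey` is the struct defined earlier in nccl.cu):

```cpp
#include <cuda.h>

static channelKey makeChannelKey(const void* userPtr) {
  CUdeviceptr base = 0;
  size_t bytes = 0;
  // Resolve the enclosing allocation's base address and extent.
  cuMemGetAddressRange(&base, &bytes, (CUdeviceptr)userPtr);
  return channelKey{(void*)base, bytes};
}
```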
diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp
index 90915cf7..25d5d7f1 100644
--- a/include/mscclpp/nvls.hpp
+++ b/include/mscclpp/nvls.hpp
@@ -30,7 +30,7 @@ class NvlsConnection {
     using DeviceHandle = DeviceMulticastPointerDeviceHandle;
     DeviceMulticastPointer(void* devicePtr, std::shared_ptr<char> mcPtr, size_t bufferSize)
         : devicePtr_(devicePtr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
-    DeviceHandle deviceHandle();
+    DeviceHandle deviceHandle() const;
     void* getDevicePtr();
 
     friend class NvlsConnection;
diff --git a/src/nvls.cc b/src/nvls.cc
index c8cbd8cd..188028aa 100644
--- a/src/nvls.cc
+++ b/src/nvls.cc
@@ -198,6 +198,11 @@ std::shared_ptr<char> NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, si
                 ErrorCode::InvalidUsage);
   }
 
+  if ((uintptr_t)devicePtr % minMcGran_ != 0) {
+    WARN("NVLS connection tried to bind a buffer that is not aligned to the minimum granularity");
+    throw Error("This NVLS connection tried to bind a buffer that is not aligned to the minimum granularity",
+                ErrorCode::InvalidUsage);
+  }
   devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_;
   size_t offset = allocateBuffer(devBuffSize);
   MSCCLPP_CUTHROW(cuMulticastBindAddr(mcHandle_, offset /*mcOffset*/, devicePtr, devBuffSize, 0));
@@ -265,7 +270,7 @@ NvlsConnection::DeviceMulticastPointer NvlsConnection::bindAllocatedMemory(CUdev
   return DeviceMulticastPointer((void*)devicePtr, mcPtr, size);
 }
 
-NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() {
+NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() const {
   NvlsConnection::DeviceMulticastPointer::DeviceHandle device;
   device.devicePtr = this->devicePtr_;
   device.mcPtr = this->mcPtr_.get();
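Note: the new check in bindMemory rejects buffers whose base address is not a multiple of the minimum multicast granularity, while sizes are still rounded up to that granularity. Callers can pre-validate with the same arithmetic; the granularity value below is illustrative (query the real one from the NVLS connection at runtime):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Same round-up that bindMemory applies to devBuffSize.
constexpr size_t roundUp(size_t n, size_t gran) { return ((n + gran - 1) / gran) * gran; }

int main() {
  const size_t minMcGran = 2u << 20;             // e.g. 2 MB (assumed value)
  assert(roundUp(3u << 20, minMcGran) == (4u << 20));
  const uintptr_t base = 0x7f0000000000;         // hypothetical allocation base
  assert(base % minMcGran == 0);                 // bindMemory now requires this
  return 0;
}
```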