diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp
index eebc648a..cc6f23a0 100644
--- a/apps/nccl/src/allreduce.hpp
+++ b/apps/nccl/src/allreduce.hpp
@@ -10,6 +10,7 @@
 #include <mscclpp/gpu_data_types.hpp>
 #include <mscclpp/memory_channel.hpp>
 #include <mscclpp/memory_channel_device.hpp>
+#include <mscclpp/nvls.hpp>
 #include <mscclpp/packet_device.hpp>
 
 #include <type_traits>
@@ -164,6 +165,34 @@ __forceinline__ __device__ DataType cal_vectors(DataType a, DataType b) {
   return cal_vectors_helper(a, b);
 }
 
+// Maps an element type to its packed compute type and to the vector types used for NVLS
+// multimem accesses: bulk uint4 (16-byte) loads/stores, with uint1 for the unaligned tail.
+template <typename T>
+struct VectorType {
+  using type = T;
+  using nvls_type = T;
+  using nvls_type2 = T;
+};
+
+template <>
+struct VectorType<__half> {
+  using type = __half2;
+  using nvls_type = uint4;
+  using nvls_type2 = uint1;
+};
+
+template <>
+struct VectorType<__bfloat16> {
+  using type = __bfloat162;
+  using nvls_type = uint4;
+  using nvls_type2 = uint1;
+};
+
+template <>
+struct VectorType<float> {
+  using type = float;
+  using nvls_type = uint4;
+  using nvls_type2 = uint1;
+};
+
 template <typename T>
 __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff,
                                   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
@@ -477,14 +506,185 @@ __global__ void __launch_bounds__(512, 1)
   }
 }
 
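+// handleMultiLoadReduceStore is the switch-side reduction primitive: multimemLoadReduce
+// reads through an NVLS multicast pointer and the NVSwitch returns the element-wise sum
+// across all ranks' replicas, then multimemStore broadcasts the result back through the
+// same multicast mapping. For __half, each pair lowers to PTX roughly like:
+//   multimem.ld_reduce.relaxed.sys.global.add.v4.f16x2 {d0,d1,d2,d3}, [addr];
+//   multimem.st.relaxed.sys.global.v4.f16x2 [addr], {d0,d1,d2,d3};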
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+template <typename T>
+MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, uint32_t srcOffset, uint32_t dstOffset,
+                                                      size_t size, int tid, int nThreads) {
+  using vectorType = typename VectorType<T>::type;
+  using nvlsType = typename VectorType<T>::nvls_type;
+  // NVLS can only handle accesses aligned to at least 4 bytes.
+  assert(size % sizeof(vectorType) == 0);
+  const size_t nInt4 = size / sizeof(nvlsType);
+  const size_t srcOffset4 = srcOffset / sizeof(nvlsType);
+  const size_t dstOffset4 = dstOffset / sizeof(nvlsType);
+  nvlsType* src4 = (nvlsType*)src;
+  nvlsType* dst4 = (nvlsType*)dst;
+  for (size_t idx = tid; idx < nInt4; idx += nThreads) {
+    nvlsType val;
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (vectorType*)(src4 + srcOffset4 + idx));
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (vectorType*)(dst4 + dstOffset4 + idx));
+  }
+  // Handle the tail that does not fill a whole nvlsType vector.
+  size_t processed = nInt4 * sizeof(nvlsType);
+  using nvlsType2 = typename VectorType<T>::nvls_type2;
+  const size_t startIdx = (srcOffset + processed) / sizeof(nvlsType2);
+  const size_t endIdx = (dstOffset + size) / sizeof(nvlsType2);
+  for (size_t idx = tid + startIdx; idx < endIdx; idx += nThreads) {
+    nvlsType2 val;
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemLoadReduce(val, (vectorType*)src + idx);
+    mscclpp::DeviceMulticastPointerDeviceHandle::multimemStore(val, (vectorType*)dst + idx);
+  }
+}
+#endif
+
+template <typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduce9([[maybe_unused]] mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
+               [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicast,
+               [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastOut,
+               [[maybe_unused]] size_t channelInOffset, [[maybe_unused]] size_t channelOutOffset,
+               [[maybe_unused]] size_t size, [[maybe_unused]] int rank) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  int nBlocks = gridDim.x;
+  int bid = blockIdx.x;
+  size_t sizePerRank = size / NRANKS_PER_NODE;
+  size_t sizePerBlock = sizePerRank / nBlocks;
+  size_t rankOffset = sizePerRank * rank;
+  size_t blockOffset = sizePerBlock * bid + rankOffset;
+  // Each block reduces through its own NVLS connection.
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastPtr = multicast + bid;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastOutPtr = multicastOut + bid;
+
+  const size_t chanOffset = (NRANKS_PER_NODE - 1) * blockIdx.x;
+  auto memoryChans = memoryChannels + chanOffset;
+  __shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[NRANKS_PER_NODE - 1];
+  const int lid = threadIdx.x % WARP_SIZE;
+  if (lid < NRANKS_PER_NODE - 1) {
+    channels[lid] = memoryChans[lid];
+  }
+  __syncwarp();
+  if (threadIdx.x < NPEERS) {
+    channels[threadIdx.x].relaxedSignal();
+    channels[threadIdx.x].relaxedWait();
+  }
+  __syncthreads();
+  T* src = (T*)multicastPtr->mcPtr;
+  T* dst = (T*)multicastOutPtr->mcPtr;
+  handleMultiLoadReduceStore(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset, sizePerBlock,
+                             threadIdx.x, blockDim.x);
+  __syncthreads();
+  if (threadIdx.x < NPEERS) {
+    channels[threadIdx.x].relaxedSignal();
+    channels[threadIdx.x].relaxedWait();
+  }
+#endif
+}
+
+template <typename T>
+__global__ void __launch_bounds__(1024, 1)
+    allreduce10([[maybe_unused]] const void* src, [[maybe_unused]] void* scratch, [[maybe_unused]] void* dst,
+                [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
+                [[maybe_unused]] mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicast,
+                [[maybe_unused]] size_t size, [[maybe_unused]] int rank) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
+  int nBlocks = gridDim.x;
+  int nBlocksPerNvlsConn = nBlocks / NUM_NVLS_CONNECTION;
+  int bid = blockIdx.x;
+  size_t sizePerRank = size / NRANKS_PER_NODE;
+  constexpr size_t scratchSizePerRank = SCRATCH_SIZE / NRANKS_PER_NODE;
+  const size_t maxSizePerBlock = (sizePerRank + nBlocks - 1) / nBlocks;
+  size_t start = bid * maxSizePerBlock;
+  size_t end = min(start + maxSizePerBlock, sizePerRank);
+  size_t sizePerBlock = end - start;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* multicastPtr =
+      multicast + bid / nBlocksPerNvlsConn;
+  size_t copyPerIter = 1024 * 16;
+  if (sizePerBlock >= 1024 * 64) {
+    copyPerIter = 1024 * 32;
+  }
+  size_t scratchSizePerBlock = (scratchSizePerRank / nBlocks) / copyPerIter * copyPerIter;
+  size_t blockScratchOffset = scratchSizePerBlock * bid + scratchSizePerRank * rank;
+  constexpr int NCOPY_WARPS = 14;
+  constexpr int NREDUCE_WARPS = 4;
+  constexpr int NRECV_COPY_WARPS = 14;
+  constexpr int endCopyWid = NCOPY_WARPS;
+  constexpr int startRecvCopyWid = NCOPY_WARPS;
+  constexpr int endRecvCopyWid = NCOPY_WARPS + NRECV_COPY_WARPS;
+  constexpr int endReduceWid = NCOPY_WARPS + NREDUCE_WARPS + NRECV_COPY_WARPS;
+  const int warpId = threadIdx.x / WARP_SIZE;
+  size_t nIter = sizePerBlock / copyPerIter;
+  size_t lastIterSize = copyPerIter;
+  if (sizePerBlock % copyPerIter != 0) {
+    nIter += 1;
+    lastIterSize = sizePerBlock % copyPerIter;
+  }
+
+  const size_t chanOffset = (NRANKS_PER_NODE - 1) * blockIdx.x * 2;
+  auto memoryChans = memoryChannels + chanOffset;
+  __shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[(NRANKS_PER_NODE - 1) * 2];
+  const int lid = threadIdx.x % WARP_SIZE;
+  if (lid < (NRANKS_PER_NODE - 1) * 2) {
+    channels[lid] = memoryChans[lid];
+  }
+  __syncwarp();
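+  // Per-iteration pipeline across the 32 warps: warps [0,14) stage the local input into
+  // the multicast-bound scratch and signal/wait peers so every rank has staged its data;
+  // warps [28,32) NVLS-reduce this rank's shard in scratch (the multimem store broadcasts
+  // the reduced values to all ranks); warps [14,28) signal/wait again and copy the
+  // reduced shards out to dst. Named barriers (bar.sync) chain the three stages so
+  // successive iterations overlap.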
+  for (int it = 0; it < nIter; it++) {
+    const size_t iterSize = (it == nIter - 1) ? lastIterSize : copyPerIter;
+    if (warpId < endCopyWid) {
+      int tidInCopy = threadIdx.x;
+      for (int i = 0; i < NRANKS_PER_NODE; i++) {
+        size_t offset = i * sizePerRank + sizePerBlock * bid + it * copyPerIter;
+        size_t offsetScratch =
+            i * scratchSizePerRank + scratchSizePerBlock * bid + (it * copyPerIter) % scratchSizePerBlock;
+        char* srcData = (char*)src + offset;
+        char* dstData = (char*)scratch + offsetScratch;
+        mscclpp::copy(dstData, srcData, iterSize, tidInCopy, NCOPY_WARPS * WARP_SIZE);
+      }
+      asm volatile("bar.sync %0, %1;" ::"r"(0), "r"(NCOPY_WARPS * WARP_SIZE) : "memory");
+      if (tidInCopy < NPEERS) {
+        channels[tidInCopy].signal();
+        channels[tidInCopy].wait();
+      }
+      asm volatile("bar.sync %0, %1;" ::"r"(1), "r"((NCOPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+    }
+    if (warpId >= endRecvCopyWid && warpId < endReduceWid) {
+      int tidInReduce = threadIdx.x - endRecvCopyWid * WARP_SIZE;
+      asm volatile("bar.sync %0, %1;" ::"r"(1), "r"((NCOPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+      T* mcBuff = (T*)multicastPtr->mcPtr;
+      size_t offset = blockScratchOffset + (it * copyPerIter) % scratchSizePerBlock;
+      handleMultiLoadReduceStore(mcBuff, mcBuff, offset, offset, iterSize, tidInReduce, NREDUCE_WARPS * WARP_SIZE);
+      asm volatile("bar.sync %0, %1;" ::"r"(2), "r"((NRECV_COPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+    }
+    if (warpId >= startRecvCopyWid && warpId < endRecvCopyWid) {
+      int tidInRecvCopy = threadIdx.x - startRecvCopyWid * WARP_SIZE;
+      asm volatile("bar.sync %0, %1;" ::"r"(2), "r"((NRECV_COPY_WARPS + NREDUCE_WARPS) * WARP_SIZE) : "memory");
+      if (tidInRecvCopy < NPEERS) {
+        channels[tidInRecvCopy + NPEERS].signal();
+        channels[tidInRecvCopy + NPEERS].wait();
+      }
+      asm volatile("bar.sync %0, %1;" ::"r"(3), "r"(NRECV_COPY_WARPS * WARP_SIZE) : "memory");
+      for (int i = 0; i < NRANKS_PER_NODE; i++) {
+        size_t offset = i * sizePerRank + sizePerBlock * bid + it * copyPerIter;
+        size_t offsetScratch =
+            i * scratchSizePerRank + scratchSizePerBlock * bid + (it * copyPerIter) % scratchSizePerBlock;
+        char* srcData = (char*)scratch + offsetScratch;
+        char* dstData = (char*)dst + offset;
+        mscclpp::copy(dstData, srcData, iterSize, tidInRecvCopy, NRECV_COPY_WARPS * WARP_SIZE);
+      }
+    }
+  }
+#endif
+}
+
 template <typename T>
 cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
                       mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
-                      mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
-                      size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
-                      size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7, uint32_t* deviceFlag28,
-                      uint32_t* deviceFlag56, uint32_t numScratchBuff) {
-  uint32_t* deviceFlag;
+                      mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels,
+                      mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsChannels,
+                      mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsOutChannels,
+                      size_t channelInOffset, size_t channelOutOffset, size_t channelScratchOffset, int rank,
+                      int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream, uint32_t* deviceFlag7,
+                      uint32_t* deviceFlag28, uint32_t* deviceFlag56, uint32_t numScratchBuff) {
+  bool useNvlsWithZeroCopy = mscclpp::isNvlsSupported() && !mscclppDisableChannelCache;
+
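+  // Kernel selection (sketch): tiny inputs take the all-pairs latency kernel; small
+  // inputs (<=64 KB, or <=1 MB when NVLS zero-copy is unavailable) take the packet
+  // kernels; larger inputs use NVLS in-place reduction (allreduce9) when buffers can be
+  // bound directly, the NVLS pipeline staged through scratch (allreduce10) otherwise,
+  // and the non-NVLS kernel as the final fallback.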
   if (sizeof(T) * nelems < worldSize * sizeof(int)) {
     int nBlocks = 7;
     int nThreadsPerBlock = 32;
@@ -497,10 +697,10 @@ cudaError_t allreduce(const void* buff, void* scratch, void* resultBuff,
     allreduceAllPairs<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
         (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
         nRanksPerNode, worldSize, nelems, deviceFlag28, numScratchBuff);
-  } else if (sizeof(T) * nelems <= (1 << 20)) {
+  } else if (sizeof(T) * nelems <= (1 << 16) || (sizeof(T) * nelems <= (1 << 20) && !useNvlsWithZeroCopy)) {
     int nBlocks = 28;
     int nThreadsPerBlock = 1024;
-    deviceFlag = deviceFlag28;
+    uint32_t* deviceFlag = deviceFlag28;
     if (nelems >= 8192) {
       nBlocks = 56;
       nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
@@ -517,6 +717,16 @@
         (T*)buff, (T*)scratch, (T*)resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank,
         nRanksPerNode, worldSize, nelems, deviceFlag, numScratchBuff);
 #endif
+  } else if (useNvlsWithZeroCopy) {
+    int nBlocks = 8;
+    int nThreadsPerBlock = 1024;
+    allreduce9<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
+        memoryChannels, nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, nelems * sizeof(T), rank);
+  } else if (mscclpp::isNvlsSupported()) {
+    int nBlocks = 32;
+    int nThreadsPerBlock = 1024;
+    allreduce10<T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, nvlsChannels,
+                                                             nelems * sizeof(T), rank);
   } else {
     int nBlocks = 35;
     int nThreadsPerBlock = 512;
diff --git a/apps/nccl/src/common.hpp b/apps/nccl/src/common.hpp
index 015e0a2f..f6db6ea7 100644
--- a/apps/nccl/src/common.hpp
+++ b/apps/nccl/src/common.hpp
@@ -5,6 +5,7 @@
 #define NCCL_COMMON_HPP_
 
 #include <mscclpp/concurrency_device.hpp>
+#include <mscclpp/env.hpp>
 
 #if defined(__HIP_PLATFORM_AMD__)
 #define WARP_SIZE 64
@@ -13,10 +14,13 @@
 #define WARP_SIZE 32
 #endif
 
+constexpr int NUM_NVLS_CONNECTION = 8;
+
 constexpr int NRANKS_PER_NODE = 8;
 constexpr int NPEERS = 7;
 constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70;  // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB
 
+static bool mscclppDisableChannelCache = mscclpp::env()->disableChannelCache;
 
 __device__ mscclpp::DeviceSyncer deviceSyncer;
diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu
index 739d3ec8..a8962bbb 100644
--- a/apps/nccl/src/nccl.cu
+++ b/apps/nccl/src/nccl.cu
@@ -10,7 +10,9 @@
 #include <mscclpp/concurrency_device.hpp>
 #include <mscclpp/core.hpp>
 #include <mscclpp/env.hpp>
+#include <mscclpp/gpu_utils.hpp>
 #include <mscclpp/memory_channel.hpp>
+#include <mscclpp/nvls.hpp>
 #include <mscclpp/utils.hpp>
 #include <sstream>
 #include <unordered_map>
@@ -37,6 +39,7 @@
 } while (0)
 
 #define NUM_CHANNELS_PER_CONNECTION 64
+static constexpr size_t NVLS_BUFFER_SIZE = (1 << 30);
 
 typedef enum mscclppNcclDlopenErr {
   dlopenSuccess = 0,
@@ -139,7 +142,6 @@ static inline int mscclppNcclInFallbackList(const char* collOps, const char* fallbackList) {
 
 // Declare the global map to store associations between raw pointer and shared pointer
 static std::unordered_map<void*, std::shared_ptr<char>> ptrMap;
-static bool mscclppDisableChannelCache = mscclpp::env()->disableChannelCache;
 
 struct channelKey {
   const void* buff;
@@ -172,6 +174,11 @@ struct ChannelInfo {
   std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
 };
 
+struct NvlsChannelInfo {
+  std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> nvlsChannels;
+  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>> nvlsChannelDeviceHandles;
+};
+
 struct splitCommInfo {
   int color;
   int key;
@@ -181,6 +188,8 @@ struct splitCommInfo {
 struct ncclComm {
   std::shared_ptr<mscclpp::Communicator> comm;
   std::vector<std::shared_ptr<mscclpp::Connection>> connections;
+  std::vector<std::shared_ptr<mscclpp::NvlsConnection>> nvlsConnections;
+  std::vector<std::shared_ptr<mscclpp::NvlsConnection>> nvlsConnectionsOut;
   std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
   std::shared_ptr<mscclpp::Executor> executor;
   std::unordered_map<std::string, std::vector<executionPlanInstance>> executionPlans;
@@ -188,6 +197,7 @@ struct ncclComm {
   std::unordered_map<channelKey, ChannelInfo> channelInInfos;
   std::unordered_map<channelKey, ChannelInfo> channelOutInfos;
   std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
+  std::unordered_map<channelKey, NvlsChannelInfo> channelNvlsInfos;
   std::shared_ptr<char> scratchBuff;
   std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
   std::vector<ChannelInfo> channelInfos;
@@ -287,6 +297,34 @@ static std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
   return channels;
 }
 
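+// Establish NUM_NVLS_CONNECTION multicast (NVLS) connections spanning all ranks of the
+// node. Each connection owns an independent multicast address range, which (presumably)
+// lets different groups of thread blocks reduce through separate NVLS handles instead of
+// serializing on a single one.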
+static std::vector<std::shared_ptr<mscclpp::NvlsConnection>> setupNvlsConnections(ncclComm_t comm, size_t size) {
+  std::vector<std::shared_ptr<mscclpp::NvlsConnection>> nvlsConnections;
+  int nRanks = comm->comm->bootstrap()->getNranks();
+  std::vector<int> ranks;
+  for (int i = 0; i < nRanks; i++) {
+    ranks.push_back(i);
+  }
+  for (int i = 0; i < NUM_NVLS_CONNECTION; i++) {
+    std::shared_ptr<mscclpp::NvlsConnection> nvlsConnection = mscclpp::connectNvlsCollective(comm->comm, ranks, size);
+    nvlsConnections.push_back(nvlsConnection);
+  }
+  return nvlsConnections;
+}
+
+static std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> setupNvlsChannels(
+    std::vector<std::shared_ptr<mscclpp::NvlsConnection>> conns, void* buffer, size_t bufferSize) {
+  std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels;
+
+  for (size_t idx = 0; idx < conns.size(); ++idx) {
+    std::shared_ptr<mscclpp::NvlsConnection> nvlsConnection = conns[idx];
+    mscclpp::NvlsConnection::DeviceMulticastPointer deviceMulticastPointer =
+        nvlsConnection->bindAllocatedMemory((CUdeviceptr)buffer, bufferSize);
+    channels.push_back(deviceMulticastPointer);
+  }
+  return channels;
+}
+
 static std::pair<std::string, executionPlanInstance> loadExecutionPlan(const std::string& filename) {
   std::shared_ptr<mscclpp::ExecutionPlan> plan = std::make_shared<mscclpp::ExecutionPlan>(filename);
   std::string collective = plan->collective();
@@ -307,6 +345,21 @@ static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> setupMemoryChannelDeviceHandles(
   return ptr;
 }
 
+static std::shared_ptr<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>>
+setupNvlsChannelDeviceHandles(const std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer>& nvlsChannels) {
+  std::shared_ptr<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>> ptr =
+      mscclpp::detail::gpuCallocShared<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>>(
+          nvlsChannels.size());
+  std::vector<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>> nvlsChannelDeviceHandles;
+  std::transform(nvlsChannels.begin(), nvlsChannels.end(), std::back_inserter(nvlsChannelDeviceHandles),
+                 [](const mscclpp::NvlsConnection::DeviceMulticastPointer& nvlsChannel) {
+                   return mscclpp::deviceHandle(nvlsChannel);
+                 });
+  mscclpp::gpuMemcpy<mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>>(
+      ptr.get(), nvlsChannelDeviceHandles.data(), nvlsChannelDeviceHandles.size(), cudaMemcpyHostToDevice);
+  return ptr;
+}
+
 static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                                           ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
   // FallBack for single node
@@ -337,10 +390,49 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   channelKey recvKey{(void*)recvBasePtr, recvBytes};
   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels = nullptr;
   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels = nullptr;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsChannels = nullptr;
+  mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>* nvlsOutChannels = nullptr;
   size_t bytes = count * ncclTypeSize(datatype);
+  bool useNvlsWithZeroCopy = mscclpp::isNvlsSupported() && !mscclppDisableChannelCache;
+  bool useNvlsWithCopy = mscclpp::isNvlsSupported() && mscclppDisableChannelCache;
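+  // Two NVLS modes: with the channel cache enabled (zero-copy), the user's send/recv
+  // buffers are bound to multicast memory directly and allreduce9 reduces them in place;
+  // with the cache disabled, only the long-lived scratch buffer is bound and allreduce10
+  // stages data through it instead.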
 
   // Creating the channels
-  if (count * ncclTypeSize(datatype) <= (1 << 20)) {
+  if (useNvlsWithZeroCopy) {
+    auto nvlsIt = comm->channelNvlsInfos.find(sendKey);
+    if (nvlsIt == comm->channelNvlsInfos.end()) {
+      std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels =
+          setupNvlsChannels(comm->nvlsConnections, (void*)sendBasePtr, sendBytes);
+      NvlsChannelInfo channelInfo{channels, setupNvlsChannelDeviceHandles(channels)};
+      nvlsIt = comm->channelNvlsInfos.emplace(sendKey, channelInfo).first;
+    }
+    nvlsChannels = nvlsIt->second.nvlsChannelDeviceHandles.get();
+    if (recvbuff != sendbuff) {
+      auto nvlsOutIt = comm->channelNvlsInfos.find(recvKey);
+      if (nvlsOutIt == comm->channelNvlsInfos.end()) {
+        std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels =
+            setupNvlsChannels(comm->nvlsConnectionsOut, (void*)recvBasePtr, recvBytes);
+        NvlsChannelInfo channelInfo{channels, setupNvlsChannelDeviceHandles(channels)};
+        nvlsOutIt = comm->channelNvlsInfos.emplace(recvKey, channelInfo).first;
+      }
+      nvlsOutChannels = nvlsOutIt->second.nvlsChannelDeviceHandles.get();
+    } else {
+      nvlsOutChannels = nvlsChannels;
+    }
+  }
+
+  if (useNvlsWithCopy) {
+    channelKey sendKey{(void*)(comm->scratchBuff.get()), SCRATCH_SIZE};
+    auto nvlsIt = comm->channelNvlsInfos.find(sendKey);
+    if (nvlsIt == comm->channelNvlsInfos.end()) {
+      std::vector<mscclpp::NvlsConnection::DeviceMulticastPointer> channels =
+          setupNvlsChannels(comm->nvlsConnections, (void*)comm->scratchBuff.get(), SCRATCH_SIZE);
+      NvlsChannelInfo channelInfo{channels, setupNvlsChannelDeviceHandles(channels)};
+      nvlsIt = comm->channelNvlsInfos.emplace(sendKey, channelInfo).first;
+    }
+    nvlsChannels = nvlsIt->second.nvlsChannelDeviceHandles.get();
+  }
+
+  if (count * ncclTypeSize(datatype) <= (1 << 20) || mscclpp::isNvlsSupported()) {
     auto sendIt = comm->channelScratchInfos.find(sendKey);
     if (sendIt == comm->channelScratchInfos.end()) {
       std::vector<mscclpp::MemoryChannel> channels =
@@ -386,8 +478,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   Op reduceOp = getReduceOp(op);
   std::function<cudaError_t(const void*, void*, void*, mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
-                            mscclpp::DeviceHandle<mscclpp::MemoryChannel>*, size_t, size_t, size_t, int, int, int,
-                            size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
+                            mscclpp::DeviceHandle<mscclpp::MemoryChannel>*,
+                            mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>*,
+                            mscclpp::DeviceHandle<mscclpp::NvlsConnection::DeviceMulticastPointer>*, size_t, size_t,
+                            size_t, int, int, int, size_t, cudaStream_t, uint32_t*, uint32_t*, uint32_t*, int)>
       allreduceFunc;
   if (reduceOp == SUM) {
     if (datatype == ncclFloat16) {
@@ -416,11 +510,11 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
       return ncclInvalidArgument;
     }
   }
-  CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, offsetIn,
-                          offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
-                          comm->comm->bootstrap()->getNranks(), count, stream, (uint32_t*)comm->deviceFlag7.get(),
-                          (uint32_t*)comm->deviceFlag28.get(), (uint32_t*)comm->deviceFlag56.get(),
-                          comm->numScratchBuff));
+  CUDACHECK(allreduceFunc(sendbuff, comm->scratchBuff.get(), recvbuff, memoryChannels, memoryOutChannels, nvlsChannels,
+                          nvlsOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
+                          NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream,
+                          (uint32_t*)comm->deviceFlag7.get(), (uint32_t*)comm->deviceFlag28.get(),
+                          (uint32_t*)comm->deviceFlag56.get(), comm->numScratchBuff));
   return ncclSuccess;
 }
@@ -447,7 +541,7 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
   MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)sendbuff));
   size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
   channelKey recvKey{(void*)recvBasePtr, recvBytes};
-  [[maybe_unused]] channelKey sendKey{(void*)sendBasePtr, sendBytes};
+  [[maybe_unused]] channelKey sendKey{(void*)comm->scratchBuff.get(), SCRATCH_SIZE};
   int rank = comm->comm->bootstrap()->getRank();
   int nRank = comm->comm->bootstrap()->getNranks();
   mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels = nullptr;
@@ -533,6 +627,10 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_ptr<mscclpp::Communicator> mscclppComm, int rank) {
   mscclppComm->setup();
 
   commPtr->connections = std::move(connections);
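+  // NVLS connections are created eagerly at communicator init: each reserves
+  // NVLS_BUFFER_SIZE (1 GiB) of multicast address space, and buffers bound later must
+  // fit inside that reservation. The separate "out" set lets allreduce9 bind send and
+  // recv buffers to independent multicast ranges.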
+  if (mscclpp::isNvlsSupported()) {
+    commPtr->nvlsConnections = setupNvlsConnections(commPtr, NVLS_BUFFER_SIZE);
+    commPtr->nvlsConnectionsOut = setupNvlsConnections(commPtr, NVLS_BUFFER_SIZE);
+  }
   commPtr->memorySemaphores = std::move(memorySemaphores);
   commPtr->buffFlag = 0;
   commPtr->numScratchBuff = 2;
diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp
index 90915cf7..25d5d7f1 100644
--- a/include/mscclpp/nvls.hpp
+++ b/include/mscclpp/nvls.hpp
@@ -30,7 +30,7 @@ class NvlsConnection {
     using DeviceHandle = DeviceMulticastPointerDeviceHandle;
     DeviceMulticastPointer(void* devicePtr, std::shared_ptr<char> mcPtr, size_t bufferSize)
         : devicePtr_(devicePtr), mcPtr_(mcPtr), bufferSize_(bufferSize) {}
-    DeviceHandle deviceHandle();
+    DeviceHandle deviceHandle() const;
     void* getDevicePtr();
 
     friend class NvlsConnection;
diff --git a/src/nvls.cc b/src/nvls.cc
index c8cbd8cd..188028aa 100644
--- a/src/nvls.cc
+++ b/src/nvls.cc
@@ -198,6 +198,11 @@ std::shared_ptr<char> NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, size_t devBuffSize) {
                 ErrorCode::InvalidUsage);
   }
 
+  if ((uintptr_t)devicePtr % minMcGran_ != 0) {
+    WARN("NVLS connection tried to bind a buffer that is not aligned to the minimum granularity");
+    throw Error("This NVLS connection tried to bind a buffer that is not aligned to the minimum granularity",
+                ErrorCode::InvalidUsage);
+  }
   devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_;
   size_t offset = allocateBuffer(devBuffSize);
   MSCCLPP_CUTHROW(cuMulticastBindAddr(mcHandle_, offset /*mcOffset*/, devicePtr, devBuffSize, 0));
@@ -265,7 +270,7 @@ NvlsConnection::DeviceMulticastPointer NvlsConnection::bindAllocatedMemory(CUdeviceptr devicePtr, size_t size) {
   return DeviceMulticastPointer((void*)devicePtr, mcPtr, size);
 }
 
-NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() {
+NvlsConnection::DeviceMulticastPointer::DeviceHandle NvlsConnection::DeviceMulticastPointer::deviceHandle() const {
   NvlsConnection::DeviceMulticastPointer::DeviceHandle device;
   device.devicePtr = this->devicePtr_;
   device.mcPtr = this->mcPtr_.get();