From 533f329971e003e2ca67803c19959d13bf7140ea Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Tue, 28 Apr 2026 16:23:23 +0000 Subject: [PATCH] Tune no-sym MNNVL with RSAG zero-copy Disable NVLS zero-copy when symmetric memory is not enabled, and allow the RSAG zero-copy path to participate in MNNVL tuning for non-symmetric memory. Cache RSAG zero-copy contexts by the concrete buffer pointers so CUDA graph capture does not create a new registration for every execute call, and reject requests whose block count exceeds the channel count with CommInvalidArgument. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../torch-integration/customized_comm_with_tuning.py | 6 +++--- .../collectives/allreduce/allreduce_nvls_zero_copy.cu | 7 ++++++- .../collectives/allreduce/allreduce_rsag_zero_copy.cu | 9 ++++++--- .../include/allreduce/allreduce_nvls_zero_copy.hpp | 1 + 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index 0736cb68..6f8f097d 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -130,7 +130,7 @@ class CustomizedComm: "default_allreduce_packet": 112, "default_allreduce_allpair_packet": 56, "default_allreduce_rsag": 128, - "default_allreduce_rsag_zero_copy": 64, + "default_allreduce_rsag_zero_copy": 128, "default_allreduce_fullmesh": 64, "default_allgather_fullmesh2": 32, } @@ -252,10 +252,10 @@ class CustomizedComm: out.append(a) if size >= 512 << 10: a = self._algo("allreduce", "default_allreduce_rsag_zero_copy") - if self.symmetric_memory and a: + if a: out.append(a) a = self._algo("allreduce", "default_allreduce_nvls_zero_copy") - if self._nvls and a: + if self._nvls and self.symmetric_memory and a: out.append(a) a = self._algo("allreduce", "default_allreduce_rsag") if a: diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu 
b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 8c360f96..25077004 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -122,6 +122,10 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr ctx_vo cudaStream_t stream, int nBlocks, int nThreadsPerBlock, [[maybe_unused]] const std::unordered_map& extras, mscclpp::DataType accumDtype) { + if (!symmetricMemory_) { + WARN("AllreduceNvls requires symmetric memory for now."); + return CommResult::CommInvalidArgument; + } auto ctx = std::static_pointer_cast(ctx_void); AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { @@ -165,7 +169,8 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr ctx_vo } mscclpp::AlgorithmCtxKey AllreduceNvls::generateAllreduceContextKey(const void* input, void* output, size_t, - mscclpp::DataType, bool) { + mscclpp::DataType, bool symmetricMemory) { + symmetricMemory_ = symmetricMemory; size_t sendBytes, recvBytes; CUdeviceptr sendBasePtr, recvBasePtr; MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input)); diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index c4dea321..a11da0f8 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -153,6 +153,10 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; + if (numBlocksAndThreads.first > nChannelsPerConnection_) { + WARN(ALGO, "Block number ", numBlocksAndThreads.first, " exceeds the maximum limit ", nChannelsPerConnection_); + return CommResult::CommInvalidArgument; + } cudaError_t error = allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(), nullptr, nullptr, 0, 0, 0, 
algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, @@ -165,9 +169,8 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_ptr comm, const void* input, diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp index 39615280..c40bd2cd 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp @@ -15,6 +15,7 @@ class AllreduceNvls : public AlgorithmBuilder { std::shared_ptr build() override; private: + bool symmetricMemory_ = false; void initialize(std::shared_ptr comm); CommResult allreduceKernelFunc(const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,