Tune no-sym MNNVL with RSAG zero-copy

Disable NVLS zero-copy when symmetric memory is not enabled, and allow the RSAG zero-copy path to participate in MNNVL tuning for non-symmetric memory. Cache RSAG zero-copy contexts by the concrete buffer pointers so CUDA graph capture does not create a new registration for every execute call, and reject block counts that exceed the number of channels per connection.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-04-28 16:23:23 +00:00
parent 3bc00cb7f0
commit 533f329971
4 changed files with 16 additions and 7 deletions

View File

@@ -130,7 +130,7 @@ class CustomizedComm:
"default_allreduce_packet": 112,
"default_allreduce_allpair_packet": 56,
"default_allreduce_rsag": 128,
"default_allreduce_rsag_zero_copy": 64,
"default_allreduce_rsag_zero_copy": 128,
"default_allreduce_fullmesh": 64,
"default_allgather_fullmesh2": 32,
}
@@ -252,10 +252,10 @@ class CustomizedComm:
out.append(a)
if size >= 512 << 10:
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
if self.symmetric_memory and a:
if a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
if self._nvls and a:
if self._nvls and self.symmetric_memory and a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_rsag")
if a:

View File

@@ -122,6 +122,10 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras,
mscclpp::DataType accumDtype) {
if (!symmetricMemory_) {
WARN("AllreduceNvls requires symmetric memory for now.");
return CommResult::CommInvalidArgument;
}
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
AllreduceFunc allreduce = dispatch<NvlsAdapter>(op, dtype, accumDtype);
if (!allreduce) {
@@ -165,7 +169,8 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
}
mscclpp::AlgorithmCtxKey AllreduceNvls::generateAllreduceContextKey(const void* input, void* output, size_t,
mscclpp::DataType, bool) {
mscclpp::DataType, bool symmetricMemory) {
symmetricMemory_ = symmetricMemory;
size_t sendBytes, recvBytes;
CUdeviceptr sendBasePtr, recvBasePtr;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input));

View File

@@ -153,6 +153,10 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void
return CommResult::CommInvalidArgument;
}
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
if (numBlocksAndThreads.first > nChannelsPerConnection_) {
WARN(ALGO, "Block number ", numBlocksAndThreads.first, " exceeds the maximum limit ", nChannelsPerConnection_);
return CommResult::CommInvalidArgument;
}
cudaError_t error =
allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(),
nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream,
@@ -165,9 +169,8 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void
}
AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* inputBuffer, void* outputBuffer,
size_t size, DataType, bool symmetricMemory) {
size_t size, DataType, bool symmetricMemory) {
// For non-symmetric algorithms, we use both input and output buffer pointers in the key.
static int tag = 0;
if (symmetricMemory) {
size_t inputBytes, outputBytes;
CUdeviceptr inputBasePtr, outputBasePtr;
@@ -175,7 +178,7 @@ AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* i
MSCCLPP_CUTHROW(cuMemGetAddressRange(&outputBasePtr, &outputBytes, (CUdeviceptr)outputBuffer));
return AlgorithmCtxKey{(void*)inputBasePtr, (void*)outputBasePtr, inputBytes, outputBytes, 0};
}
return AlgorithmCtxKey{(void*)inputBuffer, outputBuffer, size, size, ++tag};
return AlgorithmCtxKey{(void*)inputBuffer, outputBuffer, size, size, 0};
}
std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_ptr<Communicator> comm, const void* input,

View File

@@ -15,6 +15,7 @@ class AllreduceNvls : public AlgorithmBuilder {
std::shared_ptr<Algorithm> build() override;
private:
bool symmetricMemory_ = false;
void initialize(std::shared_ptr<Communicator> comm);
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,