mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-11 08:50:21 +00:00
Tune no-sym MNNVL with RSAG zero-copy
Disable NVLS zero-copy when symmetric memory is not enabled, and allow the RSAG zero-copy path to participate in MNNVL tuning for non-symmetric memory. Cache RSAG zero-copy contexts by the concrete buffer pointers so that CUDA graph capture does not create a new registration for every execute call, and reject requested block counts that exceed the channel count.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -130,7 +130,7 @@ class CustomizedComm:
|
||||
"default_allreduce_packet": 112,
|
||||
"default_allreduce_allpair_packet": 56,
|
||||
"default_allreduce_rsag": 128,
|
||||
"default_allreduce_rsag_zero_copy": 64,
|
||||
"default_allreduce_rsag_zero_copy": 128,
|
||||
"default_allreduce_fullmesh": 64,
|
||||
"default_allgather_fullmesh2": 32,
|
||||
}
|
||||
@@ -252,10 +252,10 @@ class CustomizedComm:
|
||||
out.append(a)
|
||||
if size >= 512 << 10:
|
||||
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
|
||||
if self.symmetric_memory and a:
|
||||
if a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_nvls_zero_copy")
|
||||
if self._nvls and a:
|
||||
if self._nvls and self.symmetric_memory and a:
|
||||
out.append(a)
|
||||
a = self._algo("allreduce", "default_allreduce_rsag")
|
||||
if a:
|
||||
|
||||
@@ -122,6 +122,10 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
|
||||
cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
[[maybe_unused]] const std::unordered_map<std::string, uintptr_t>& extras,
|
||||
mscclpp::DataType accumDtype) {
|
||||
if (!symmetricMemory_) {
|
||||
WARN("AllreduceNvls requires symmetric memory for now.");
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
|
||||
AllreduceFunc allreduce = dispatch<NvlsAdapter>(op, dtype, accumDtype);
|
||||
if (!allreduce) {
|
||||
@@ -165,7 +169,8 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
|
||||
}
|
||||
|
||||
mscclpp::AlgorithmCtxKey AllreduceNvls::generateAllreduceContextKey(const void* input, void* output, size_t,
|
||||
mscclpp::DataType, bool) {
|
||||
mscclpp::DataType, bool symmetricMemory) {
|
||||
symmetricMemory_ = symmetricMemory;
|
||||
size_t sendBytes, recvBytes;
|
||||
CUdeviceptr sendBasePtr, recvBasePtr;
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&sendBasePtr, &sendBytes, (CUdeviceptr)input));
|
||||
|
||||
@@ -153,6 +153,10 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
|
||||
if (numBlocksAndThreads.first > nChannelsPerConnection_) {
|
||||
WARN(ALGO, "Block number ", numBlocksAndThreads.first, " exceeds the maximum limit ", nChannelsPerConnection_);
|
||||
return CommResult::CommInvalidArgument;
|
||||
}
|
||||
cudaError_t error =
|
||||
allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(),
|
||||
nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream,
|
||||
@@ -165,9 +169,8 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void
|
||||
}
|
||||
|
||||
AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* inputBuffer, void* outputBuffer,
|
||||
size_t size, DataType, bool symmetricMemory) {
|
||||
size_t size, DataType, bool symmetricMemory) {
|
||||
// For non-symmetric algorithms, we use both input and output buffer pointers in the key.
|
||||
static int tag = 0;
|
||||
if (symmetricMemory) {
|
||||
size_t inputBytes, outputBytes;
|
||||
CUdeviceptr inputBasePtr, outputBasePtr;
|
||||
@@ -175,7 +178,7 @@ AlgorithmCtxKey AllreduceRsAgZeroCopy::generateAllreduceContextKey(const void* i
|
||||
MSCCLPP_CUTHROW(cuMemGetAddressRange(&outputBasePtr, &outputBytes, (CUdeviceptr)outputBuffer));
|
||||
return AlgorithmCtxKey{(void*)inputBasePtr, (void*)outputBasePtr, inputBytes, outputBytes, 0};
|
||||
}
|
||||
return AlgorithmCtxKey{(void*)inputBuffer, outputBuffer, size, size, ++tag};
|
||||
return AlgorithmCtxKey{(void*)inputBuffer, outputBuffer, size, size, 0};
|
||||
}
|
||||
|
||||
std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_ptr<Communicator> comm, const void* input,
|
||||
|
||||
@@ -15,6 +15,7 @@ class AllreduceNvls : public AlgorithmBuilder {
|
||||
std::shared_ptr<Algorithm> build() override;
|
||||
|
||||
private:
|
||||
bool symmetricMemory_ = false;
|
||||
void initialize(std::shared_ptr<Communicator> comm);
|
||||
CommResult allreduceKernelFunc(const std::shared_ptr<void> ctx, const void* input, void* output, size_t inputSize,
|
||||
DataType dtype, ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock,
|
||||
|
||||
Reference in New Issue
Block a user