Optimize MNNVL allreduce without symmetric memory

Run the tuning example with symmetric memory disabled, and make allreduce tuning use the same symmetric-memory mode as execution. Narrow the MNNVL small-message candidate set to avoid slower packet/NVLS choices. Increase packet and RSAG channel parallelism so that non-symmetric CUDA-IPC paths can use the 112-block packet and 128-block RSAG configurations.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-04-28 07:55:52 +00:00
parent dded5e0e39
commit 865c2bc795
3 changed files with 30 additions and 19 deletions

View File

@@ -122,14 +122,14 @@ class CustomizedComm:
_TUNE_N_WARMUP = 5
_TUNE_N_GRAPH_LAUNCHES = 10
_TUNE_N_OPS_PER_GRAPH = 100
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128]
_CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 112, 128]
_CANDIDATE_NTHREADS = [512, 768, 1024]
_NBLOCKS_LIMIT = {
"default_allreduce_nvls_packet": 16,
"default_allreduce_nvls_zero_copy": 32,
"default_allreduce_packet": 56,
"default_allreduce_packet": 112,
"default_allreduce_allpair_packet": 56,
"default_allreduce_rsag": 64,
"default_allreduce_rsag": 128,
"default_allreduce_rsag_zero_copy": 64,
"default_allreduce_fullmesh": 64,
"default_allgather_fullmesh2": 32,
@@ -162,6 +162,11 @@ class CustomizedComm:
def _algo(self, collective: str, name: str):
return self._algos.get((collective, name))
def _nblocks_limit(self, algo_name: str, size: int) -> int:
if algo_name == "default_allreduce_packet" and size < (1 << 20):
return 56
return self._NBLOCKS_LIMIT.get(algo_name, 128)
def _default_ar_config(self):
"""Fallback allreduce config for barrier / timing sync."""
pkt = self._algo("allreduce", "default_allreduce_nvls_packet")
@@ -218,7 +223,7 @@ class CustomizedComm:
def _barrier_internal(self):
a, nb, nt = self._default_ar_config()
self._exec_ar(self._barrier_tensor, a, nb, nt, sym=True)
self._exec_ar(self._barrier_tensor, a, nb, nt, sym=self.symmetric_memory)
# -- lazy tuning --
@@ -233,15 +238,17 @@ class CustomizedComm:
out = []
if self.multi_host_mnnvl:
if size <= 4 << 20:
a = self._algo("allreduce", "default_allreduce_packet")
if a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_allpair_packet")
if a:
out.append(a)
a = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and a:
out.append(a)
if size <= 64 << 10:
a = self._algo("allreduce", "default_allreduce_nvls_packet")
if self._nvls and a:
out.append(a)
if size > 128 << 10:
a = self._algo("allreduce", "default_allreduce_packet")
if a:
out.append(a)
if size >= 512 << 10:
a = self._algo("allreduce", "default_allreduce_rsag_zero_copy")
if self.symmetric_memory and a:
@@ -308,7 +315,7 @@ class CustomizedComm:
stream=torch.cuda.current_stream().cuda_stream,
nblocks=nb,
nthreads_per_block=nt,
symmetric_memory=True,
symmetric_memory=self.symmetric_memory,
)
else:
total = size * self.world_size
@@ -337,7 +344,7 @@ class CustomizedComm:
run = lambda a, nb, nt: self._run_tune(collective, a, buf, target_size, nb, nt)
for algo in cands:
nb_limit = self._NBLOCKS_LIMIT.get(algo.name, 128)
nb_limit = self._nblocks_limit(algo.name, target_size)
for nb in self._CANDIDATE_NBLOCKS:
if nb > nb_limit:
continue
@@ -346,7 +353,7 @@ class CustomizedComm:
ret = run(algo, nb, nt)
torch.cuda.synchronize()
self._time_buf[0] = float(ret)
self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=True)
self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=self.symmetric_memory)
if self._time_buf[0].item() != 0:
continue
used.add(algo)
@@ -375,7 +382,7 @@ class CustomizedComm:
# Cross-rank timing sync
self._time_buf.fill_(elapsed)
torch.cuda.current_stream().wait_stream(cs)
self._exec_ar(self._time_buf, *self._default_ar_config(), sym=True)
self._exec_ar(self._time_buf, *self._default_ar_config(), sym=self.symmetric_memory)
avg = self._time_buf[self.rank].item() / self.world_size
if avg < best_time:
@@ -575,7 +582,7 @@ def main():
n_iter = _get_env_int("MSCCLPP_BENCH_ITERS", default=100)
comm_group = init_dist()
cc = CustomizedComm(comm_group, symmetric_memory=True)
cc = CustomizedComm(comm_group, symmetric_memory=False)
print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...")
benchmark_allreduce(

View File

@@ -133,7 +133,7 @@ struct AllreduceRsAgAdapter {
size_t nelems = inputSize / sizeof(T);
if (nBlocks == 0 || nThreadsPerBlock == 0) {
nThreadsPerBlock = 1024;
nBlocks = 64;
nBlocks = 128;
}
allreduceRsAg<OpType, T><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank,
@@ -144,7 +144,7 @@ struct AllreduceRsAgAdapter {
void AllreduceRsAg::initialize(std::shared_ptr<Communicator> comm) {
this->conns_ = setupConnections(comm);
nChannelsPerConnection_ = 64;
nChannelsPerConnection_ = 128;
comm_ = comm;
// setup semaphores
this->scratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
@@ -179,6 +179,10 @@ CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<void> ctx, c
return CommResult::CommInvalidArgument;
}
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
if (numBlocksAndThreads.first > nChannelsPerConnection_) {
WARN(ALGO, "Block number ", numBlocksAndThreads.first, " exceeds the maximum limit ", nChannelsPerConnection_);
return CommResult::CommInvalidArgument;
}
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank,
algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0, 0,

View File

@@ -29,7 +29,7 @@ class AllreducePacket : public AlgorithmBuilder {
void* scratchBuffer_;
size_t scratchBufferSize_;
const int nSegmentsForScratchBuffer_ = 2;
const int maxBlockNum_ = 56;
const int maxBlockNum_ = 112;
std::vector<Connection> conns_;
uintptr_t flagBuffer_;
size_t flagBufferSize_;
@@ -37,4 +37,4 @@ class AllreducePacket : public AlgorithmBuilder {
std::vector<RegisteredMemory> registeredMemories_;
};
} // namespace collective
} // namespace mscclpp
} // namespace mscclpp