Fix collective topology sizing

Rename native collective context workSize to worldSize and use nRanksPerIpcDomain for allpair peer topology. Include the staged DSL signal/wait pairing validation changes.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-05-20 20:21:06 +00:00
18 changed files with 130 additions and 74 deletions

View File

@@ -79,7 +79,7 @@ __global__ void __launch_bounds__(1024)
struct Context {
int rank;
int workSize;
int worldSize;
int nRanksPerNode;
std::vector<mscclpp::RegisteredMemory> registeredMemories;
@@ -140,7 +140,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
size_t inputSize, cudaStream_t stream) {
auto algoCtx = std::static_pointer_cast<Context>(ctx);
int rank = algoCtx->rank;
int worldSize = algoCtx->workSize;
int worldSize = algoCtx->worldSize;
int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE;
allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputSize);
@@ -154,16 +154,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
void* output, size_t inputSize, mscclpp::DataType dtype) {
auto ctx = std::make_shared<Context>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// register memories
mscclpp::RegisteredMemory inputBufRegMem =
comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
mscclpp::RegisteredMemory outputBufRegMem =
comm->registerMemory(output, inputSize * ctx->workSize, mscclpp::Transport::CudaIpc);
comm->registerMemory(output, inputSize * ctx->worldSize, mscclpp::Transport::CudaIpc);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
for (int i = 0; i < ctx->workSize; i++) {
for (int i = 0; i < ctx->worldSize; i++) {
if (i == ctx->rank) continue;
comm->sendMemory(outputBufRegMem, i, 0);
remoteRegMemories.push_back(comm->recvMemory(i, 0));