mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-25 07:14:40 +00:00
Fix collective topology sizing
Rename native collective context workSize to worldSize and use nRanksPerIpcDomain for allpair peer topology. Include the staged DSL signal/wait pairing validation changes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -79,7 +79,7 @@ __global__ void __launch_bounds__(1024)
|
||||
|
||||
struct Context {
|
||||
int rank;
|
||||
int workSize;
|
||||
int worldSize;
|
||||
int nRanksPerNode;
|
||||
|
||||
std::vector<mscclpp::RegisteredMemory> registeredMemories;
|
||||
@@ -140,7 +140,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
|
||||
size_t inputSize, cudaStream_t stream) {
|
||||
auto algoCtx = std::static_pointer_cast<Context>(ctx);
|
||||
int rank = algoCtx->rank;
|
||||
int worldSize = algoCtx->workSize;
|
||||
int worldSize = algoCtx->worldSize;
|
||||
|
||||
int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE;
|
||||
allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputSize);
|
||||
@@ -154,16 +154,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
|
||||
void* output, size_t inputSize, mscclpp::DataType dtype) {
|
||||
auto ctx = std::make_shared<Context>();
|
||||
ctx->rank = comm->bootstrap()->getRank();
|
||||
ctx->workSize = comm->bootstrap()->getNranks();
|
||||
ctx->worldSize = comm->bootstrap()->getNranks();
|
||||
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
|
||||
|
||||
// register memories
|
||||
mscclpp::RegisteredMemory inputBufRegMem =
|
||||
comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
|
||||
mscclpp::RegisteredMemory outputBufRegMem =
|
||||
comm->registerMemory(output, inputSize * ctx->workSize, mscclpp::Transport::CudaIpc);
|
||||
comm->registerMemory(output, inputSize * ctx->worldSize, mscclpp::Transport::CudaIpc);
|
||||
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
|
||||
for (int i = 0; i < ctx->workSize; i++) {
|
||||
for (int i = 0; i < ctx->worldSize; i++) {
|
||||
if (i == ctx->rank) continue;
|
||||
comm->sendMemory(outputBufRegMem, i, 0);
|
||||
remoteRegMemories.push_back(comm->recvMemory(i, 0));
|
||||
|
||||
Reference in New Issue
Block a user