Fix collective topology sizing

Rename native collective context workSize to worldSize and use nRanksPerIpcDomain for allpair peer topology. Include the staged DSL signal/wait pairing validation changes.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Binyang Li
2026-05-20 20:21:06 +00:00
18 changed files with 130 additions and 74 deletions

View File

@@ -79,7 +79,7 @@ __global__ void __launch_bounds__(1024)
struct Context {
int rank;
int workSize;
int worldSize;
int nRanksPerNode;
std::vector<mscclpp::RegisteredMemory> registeredMemories;
@@ -140,7 +140,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
size_t inputSize, cudaStream_t stream) {
auto algoCtx = std::static_pointer_cast<Context>(ctx);
int rank = algoCtx->rank;
int worldSize = algoCtx->workSize;
int worldSize = algoCtx->worldSize;
int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE;
allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputSize);
@@ -154,16 +154,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
void* output, size_t inputSize, mscclpp::DataType dtype) {
auto ctx = std::make_shared<Context>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// register memories
mscclpp::RegisteredMemory inputBufRegMem =
comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc);
mscclpp::RegisteredMemory outputBufRegMem =
comm->registerMemory(output, inputSize * ctx->workSize, mscclpp::Transport::CudaIpc);
comm->registerMemory(output, inputSize * ctx->worldSize, mscclpp::Transport::CudaIpc);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
for (int i = 0; i < ctx->workSize; i++) {
for (int i = 0; i < ctx->worldSize; i++) {
if (i == ctx->rank) continue;
comm->sendMemory(outputBufRegMem, i, 0);
remoteRegMemories.push_back(comm->recvMemory(i, 0));

View File

@@ -47,7 +47,7 @@ __global__ void __launch_bounds__(1024)
struct Context {
int rank;
int workSize;
int worldSize;
int nRanksPerNode;
std::vector<mscclpp::RegisteredMemory> registeredMemories;
@@ -108,7 +108,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
cudaStream_t stream) {
auto algoCtx = std::static_pointer_cast<Context>(ctx);
int rank = algoCtx->rank;
int worldSize = algoCtx->workSize;
int worldSize = algoCtx->worldSize;
int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE;
allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputBytes);
@@ -122,16 +122,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder {
void* output, size_t inputBytes, mscclpp::DataType dtype) {
auto ctx = std::make_shared<Context>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();
// register memories
mscclpp::RegisteredMemory inputBufRegMem =
comm->registerMemory((void*)input, inputBytes, mscclpp::Transport::CudaIpc);
mscclpp::RegisteredMemory outputBufRegMem =
comm->registerMemory(output, inputBytes * ctx->workSize, mscclpp::Transport::CudaIpc);
comm->registerMemory(output, inputBytes * ctx->worldSize, mscclpp::Transport::CudaIpc);
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemories;
for (int i = 0; i < ctx->workSize; i++) {
for (int i = 0; i < ctx->worldSize; i++) {
if (i == ctx->rank) continue;
comm->sendMemory(outputBufRegMem, i, 0);
remoteRegMemories.push_back(comm->recvMemory(i, 0));

View File

@@ -78,6 +78,7 @@ class MemoryChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False):
"""Wait for a signal through the memory channel.
@@ -99,6 +100,7 @@ class MemoryChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None):
"""Retrieve data from remote memory to local memory.
@@ -508,6 +510,7 @@ class PortChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = SignalOperation(tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def wait(self, tb: int, data_sync: SyncType = SyncType.both):
"""Wait for a signal through the port channel.
@@ -527,6 +530,7 @@ class PortChannel:
tb_channel_ids = get_program().setup_channel(tb, self)
op = WaitOperation(tb_channel_ids, self.channel_type, data_sync)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type)
def flush(self, tb: int, data_sync: SyncType = SyncType.both):
"""Flush pending operations through the port channel.
@@ -636,6 +640,7 @@ class PortChannel:
)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
"""Send data from local memory to remote memory with signal and flush.
@@ -681,6 +686,7 @@ class PortChannel:
)
get_program().add_operation(self.src_rank, tb, op)
get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type)
def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
"""Transfer data from local buffer to remote scratch buffer in packet format.

View File

@@ -10,6 +10,7 @@ from mscclpp.language.rank import Semaphore
from mscclpp.language.collectives import *
from mscclpp.language.utils import AlgoSpec, ReplicationPolicy
from typing import List
from collections import defaultdict
import json
@@ -112,6 +113,9 @@ class CollectiveProgram:
self.loop_context = None
self._signal_counts = defaultdict(int)
self._wait_counts = defaultdict(int)
@classmethod
def from_spec(cls, spec: AlgoSpec):
"""Initialize a new CollectiveProgram from an algorithm specification.
@@ -206,7 +210,35 @@ class CollectiveProgram:
else:
self.gpus[rank].add_operation(tb, operation)
def register_signal(self, src_rank, dst_rank, channel_type):
"""Record that `src_rank` issued a signal targeting `dst_rank` over `channel_type`."""
self._signal_counts[(src_rank, dst_rank, channel_type)] += 1
def register_wait(self, src_rank, dst_rank, channel_type):
"""Record that `src_rank` performed a wait for `dst_rank` over `channel_type`."""
self._wait_counts[(src_rank, dst_rank, channel_type)] += 1
def validate_signal_wait_pairing(self):
"""Validate that every signal issued by a rank is matched by a wait on the peer rank.
For each (src_rank, dst_rank, channel_type) triple, the number of signals sent by
`src_rank` to `dst_rank` must equal the number of waits performed by `dst_rank`
for `src_rank` on a channel of the same type. Raises RuntimeError on mismatch.
"""
keys = set(self._signal_counts.keys()) | {(dst, src, t) for (src, dst, t) in self._wait_counts.keys()}
for src_rank, dst_rank, channel_type in keys:
signals = self._signal_counts.get((src_rank, dst_rank, channel_type), 0)
waits = self._wait_counts.get((dst_rank, src_rank, channel_type), 0)
if signals != waits:
raise RuntimeError(
f"Signal/Wait mismatch on {channel_type}: rank {src_rank} issues {signals} "
f"signal(s) to rank {dst_rank}, but rank {dst_rank} performs {waits} wait(s) "
f"for rank {src_rank}. Every signal must be matched by a corresponding wait "
f"on the peer rank over a channel of the same type."
)
def post_process_operations(self):
self.validate_signal_wait_pairing()
for gpu in self.gpus:
if self.instr_fusion:
gpu.optimize_operations()

View File

@@ -127,11 +127,11 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr<void> ct
if ((char*)input == (char*)output + rank * inputSize) {
allgatherFullmesh<false><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
(void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank,
ctx->nRanksPerIpcDomain, ctx->workSize, nElem);
ctx->nRanksPerIpcDomain, ctx->worldSize, nElem);
} else {
allgatherFullmesh<true><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
(void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank,
ctx->nRanksPerIpcDomain, ctx->workSize, nElem);
ctx->nRanksPerIpcDomain, ctx->worldSize, nElem);
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
@@ -147,7 +147,7 @@ std::shared_ptr<void> AllgatherFullmesh::initAllgatherContext(std::shared_ptr<Co
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// setup semaphores

View File

@@ -139,11 +139,11 @@ CommResult AllgatherFullmesh2::allgatherKernelFunc(const std::shared_ptr<void> c
size_t channelOutOffset = *static_cast<size_t*>(ctx->extras["channel_out_offset"].get());
if ((char*)input == (char*)output + rank * inputSize) {
allgatherFullmesh2<false><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize,
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize,
ctx->nRanksPerIpcDomain, nElem);
} else {
allgatherFullmesh2<true><<<numBlocksAndThreads.first, numBlocksAndThreads.second, 0, stream>>>(
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize,
(void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize,
ctx->nRanksPerIpcDomain, nElem);
}
cudaError_t err = cudaGetLastError();
@@ -158,7 +158,7 @@ std::shared_ptr<void> AllgatherFullmesh2::initAllgatherContext(std::shared_ptr<m
void* output, size_t inputSize, mscclpp::DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// setup semaphores

View File

@@ -7,7 +7,7 @@
#include "allreduce/allreduce_allpair_packet.hpp"
#include "allreduce/common.hpp"
#include "collective_utils.hpp"
#include "debug.h"
#include "logger.hpp"
namespace mscclpp {
namespace collective {
@@ -24,22 +24,30 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
size_t scratchBaseOffset = (flag % numScratchBuff) ? (scratchBufferSize / numScratchBuff) : 0;
size_t channelScratchOffset = scratchBaseOffset;
const int nBlocksPerPeer = gridDim.x / nPeers;
const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
const int tid = threadIdx.x + localBlockIdx * blockDim.x;
const int peerIdx = blockIdx.x / nBlocksPerPeer;
size_t srcOffset = channelDataOffset;
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
size_t scratchOffset = channelScratchOffset + rank * nelems * sizeof(LL8Packet);
void* scratchBuff = (void*)((char*)scratch + channelScratchOffset);
uint32_t* src = (uint32_t*)((char*)buff);
uint32_t* dst = (uint32_t*)((char*)resultBuff);
// step 1: write data to each peer's scratch buffer
memoryChannels[peerIdx].putPackets<LL8Packet>(scratchOffset, srcOffset, nelems * sizeof(uint32_t), tid,
blockDim.x * nBlocksPerPeer, flag);
const int warpId = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int nWarpsPerBlock = blockDim.x / WARP_SIZE;
// Assign one warp in every block to each peer. Each peer warp sends the
// same block-owned stripe, so nBlocks only partitions data and no longer
// needs to be grouped by nPeers.
if (warpId < nPeers) {
memoryChannels[warpId].putPackets<LL8Packet>(scratchOffset, channelDataOffset, nelems * sizeof(uint32_t),
lane + blockIdx.x * WARP_SIZE, gridDim.x * WARP_SIZE, flag);
}
// Safe for in-place allreduce: all peer warps must finish reading src for
// this block's stripe before any warp writes reduced data back to dst/src.
__syncthreads();
// step 2: Reduce Data
for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) {
// Split the same sent stream across all warps for reduction. warpId selects
// which strided subset to reduce while lane preserves coalesced packet reads.
for (size_t idx = lane + blockIdx.x * WARP_SIZE + warpId * WARP_SIZE * gridDim.x; idx < nelems;
idx += nWarpsPerBlock * WARP_SIZE * gridDim.x) {
uint32_t data = src[idx];
using AccRaw = std::conditional_t<std::is_same_v<T, AccumT>, uint32_t,
mscclpp::VectorType<AccumT, sizeof(uint32_t) / sizeof(T)>>;
@@ -56,16 +64,16 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand
if (threadIdx.x == 0) {
((uint32_t*)flags)[blockIdx.x] = flag + 1;
}
if (blockIdx.x == 0 && threadIdx.x >= gridDim.x && threadIdx.x < flagSize / sizeof(uint32_t)) {
((uint32_t*)flags)[threadIdx.x] = flag + 1;
if (tid >= gridDim.x && tid < flagSize / sizeof(uint32_t)) {
((uint32_t*)flags)[tid] = flag + 1;
}
}
inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int worldSize) {
if (inputSize < worldSize * sizeof(int)) {
return {worldSize - 1, 32};
inline std::pair<int, int> getDefaultBlockNumAndThreadNum(size_t inputSize, int nRanksPerIpcDomain) {
if (inputSize < nRanksPerIpcDomain * sizeof(int)) {
return {nRanksPerIpcDomain - 1, (nRanksPerIpcDomain - 1) * WARP_SIZE};
}
return {(worldSize - 1) * 4, 512};
return {(nRanksPerIpcDomain - 1) * 4, 512};
}
template <ReduceOp OpType, typename T, typename AccumT = T>
@@ -77,9 +85,6 @@ struct AllpairAdapter {
int nThreadsPerBlock = 0) {
using ChannelType = DeviceHandle<MemoryChannel>;
const size_t nelems = inputSize / sizeof(T);
// Round nBlocks to multiple of nPeers so every block maps to a valid peer.
const int nPeers = nRanksPerIpcDomain - 1;
nBlocks = (nBlocks / nPeers) * nPeers;
allreduceAllPairs<OpType, T, AccumT><<<nBlocks, nThreadsPerBlock, 0, stream>>>(
(T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank,
nRanksPerIpcDomain, worldSize, nelems, numScratchBuff, flags, flagSize);
@@ -101,18 +106,27 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
const std::unordered_map<std::string, uintptr_t>&,
DataType accumDtype) {
auto algoCtx = std::static_pointer_cast<AlgorithmCtx>(ctx);
if (algoCtx->workSize != algoCtx->nRanksPerIpcDomain) {
WARN("AllreduceAllpairPacket requires workSize to match nRanksPerIpcDomain, got workSize=%d, nRanksPerIpcDomain=%d",
algoCtx->workSize, algoCtx->nRanksPerIpcDomain);
if (algoCtx->worldSize != algoCtx->nRanksPerIpcDomain) {
WARN(ALGO,
"AllreduceAllpairPacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", algoCtx->worldSize,
", nRanksPerIpcDomain=", algoCtx->nRanksPerIpcDomain);
return CommResult::CommInvalidArgument;
}
std::pair<int, int> blockAndThreadNum{nBlocks, nThreadsPerBlock};
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->nRanksPerIpcDomain);
}
// nBlocks must be at least nPeers for allpair — each block maps to one peer.
if (blockAndThreadNum.first > maxBlockNum_) {
WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ",
maxBlockNum_, ".");
return CommResult::CommInvalidArgument;
}
const int nPeers = algoCtx->nRanksPerIpcDomain - 1;
if (blockAndThreadNum.first < nPeers) {
// The kernel maps peer sends by warpId, so every peer needs a full warp.
if (blockAndThreadNum.second % WARP_SIZE != 0 || blockAndThreadNum.second / WARP_SIZE < nPeers) {
WARN(ALGO,
"Allpair packet requires at least one full warp per peer, but got nThreadsPerBlock=", blockAndThreadNum.second,
" and nPeers=", nPeers, ".");
return CommResult::CommInvalidArgument;
}
size_t sendBytes;
@@ -122,16 +136,17 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr<voi
AllreduceFunc allreduce = dispatch<AllpairAdapter>(op, dtype, accumDtype);
if (!allreduce) {
WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast<int>(dtype));
WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast<int>(op),
", dtype=", static_cast<int>(dtype));
return CommResult::CommInvalidArgument;
}
cudaError_t error =
allreduce(input, this->scratchBuffer_, output, algoCtx->memoryChannelDeviceHandles.get(), nullptr, nullptr,
nullptr, channelInOffset, 0, this->scratchBufferSize_, algoCtx->rank, algoCtx->nRanksPerIpcDomain,
algoCtx->workSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_,
algoCtx->worldSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_,
this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error));
WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error));
return CommResult::CommUnhandledCudaError;
}
return CommResult::CommSuccess;
@@ -142,7 +157,7 @@ std::shared_ptr<void> AllreduceAllpairPacket::initAllreduceContext(std::shared_p
auto ctx = std::make_shared<AlgorithmCtx>();
const int nChannelsPerConnection = maxBlockNum_;
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
ctx->memorySemaphores = this->memorySemaphores_;
ctx->registeredMemories = this->registeredMemories_;

View File

@@ -223,7 +223,7 @@ CommResult AllreduceFullmesh::allreduceKernelFunc(
}
cudaError_t error =
allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(), ctx->memoryChannelDeviceHandles.get(),
nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize,
nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize,
stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
if (error != cudaSuccess) {
WARN("AllreduceAllconnect failed with error: %s", cudaGetErrorString(error));
@@ -249,7 +249,7 @@ std::shared_ptr<void> AllreduceFullmesh::initAllreduceContext(std::shared_ptr<Co
void* output, size_t size, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// setup semaphores

View File

@@ -205,7 +205,7 @@ CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(
}
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr,
ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_,
ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream, nullptr, 0, 0,
ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0,
blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN("AllreduceNvlsBlockPipeline failed with error: %s", cudaGetErrorString(error));
@@ -222,7 +222,7 @@ std::shared_ptr<void> AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar
void*, size_t, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// setup channels

View File

@@ -93,7 +93,7 @@ std::shared_ptr<void> AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr<
size_t, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// setup channels
@@ -123,7 +123,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr<void>
}
cudaError_t error =
allreduce(input, this->scratchBuffer_, output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(), nullptr,
0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream,
0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream,
(void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error));

View File

@@ -169,7 +169,7 @@ CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc(
}
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr,
ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_,
ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream, nullptr, 0, 0,
ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0,
blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
WARN("AllreduceNvlsWarpPipeline failed with error: %s", cudaGetErrorString(error));
@@ -186,7 +186,7 @@ std::shared_ptr<void> AllreduceNvlsWarpPipeline::initAllreduceContext(std::share
void*, size_t, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
// setup channels

View File

@@ -149,17 +149,17 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr<void> ctx_vo
// the number of GPUs. Empirically, 32 blocks works well for 4 GPUs and 16 for 8 GPUs, which
// follows the formula 128 / nGPUs, clamped to [1, MAX_NBLOCKS].
if (computeCapabilityMajor_ == 10) {
numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->workSize, MAX_NBLOCKS));
numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->worldSize, MAX_NBLOCKS));
}
}
if (numBlocksAndThreads.first > MAX_NBLOCKS) {
WARN("Number of blocks exceeds maximum supported value of %d", MAX_NBLOCKS);
return CommResult::CommInvalidArgument;
}
cudaError_t error =
allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr, nvlsChannels,
nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain,
ctx->workSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
cudaError_t error = allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr,
nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank,
ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0,
numBlocksAndThreads.first, numBlocksAndThreads.second);
if (error != cudaSuccess) {
if (error == cudaErrorNotSupported) {
WARN("AllreduceNvls does not support the requested data type.");
@@ -185,7 +185,7 @@ std::shared_ptr<void> AllreduceNvls::initAllreduceContext(std::shared_ptr<mscclp
const void* input, void* output, size_t, mscclpp::DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
size_t sendBytes, recvBytes;

View File

@@ -230,20 +230,25 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
const std::unordered_map<std::string, uintptr_t>&,
DataType accumDtype) {
auto ctx = std::static_pointer_cast<AlgorithmCtx>(ctx_void);
if (ctx->workSize != ctx->nRanksPerIpcDomain) {
WARN(ALGO, "AllreducePacket requires workSize to match nRanksPerIpcDomain, got workSize=", ctx->workSize,
if (ctx->worldSize != ctx->nRanksPerIpcDomain) {
WARN(ALGO, "AllreducePacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", ctx->worldSize,
", nRanksPerIpcDomain=", ctx->nRanksPerIpcDomain);
return CommResult::CommInvalidArgument;
}
std::pair<int, int> blockAndThreadNum = {nBlocks, nThreadsPerBlock};
if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) {
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerIpcDomain, ctx->workSize, dtype);
blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerIpcDomain, ctx->worldSize, dtype);
} else {
const int nPeers = ctx->nRanksPerIpcDomain - 1;
if (blockAndThreadNum.first < nPeers) {
return CommResult::CommInvalidArgument;
}
}
if (blockAndThreadNum.first > maxBlockNum_) {
WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ",
maxBlockNum_, ".");
return CommResult::CommInvalidArgument;
}
size_t sendBytes;
CUdeviceptr sendBasePtr;
@@ -258,7 +263,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr<void> ctx_
}
cudaError_t error =
allreduce(input, this->scratchBuffer_, output, ctx->memoryChannelDeviceHandles.get(), nullptr, nullptr, nullptr,
channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize,
channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize,
inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_,
blockAndThreadNum.first, blockAndThreadNum.second);
if (error != cudaSuccess) {
@@ -273,7 +278,7 @@ std::shared_ptr<void> AllreducePacket::initAllreduceContext(std::shared_ptr<Comm
auto ctx = std::make_shared<AlgorithmCtx>();
const int nChannelsPerConnection = maxBlockNum_;
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
ctx->memorySemaphores = this->memorySemaphores_;
ctx->registeredMemories = this->registeredMemories_;

View File

@@ -185,7 +185,7 @@ CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr<void> ctx, c
}
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank,
algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize, stream, nullptr, 0, 0,
algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream, nullptr, 0, 0,
numBlocksAndThreads.first, numBlocksAndThreads.second);
if (error != cudaSuccess) {
WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
@@ -202,7 +202,7 @@ std::shared_ptr<void> AllreduceRsAg::initAllreduceContext(std::shared_ptr<Commun
size_t, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
ctx->memorySemaphores = this->scratchSemaphores_;

View File

@@ -288,7 +288,7 @@ CommResult AllreduceRsAgPipeline::allreduceKernelFunc(
std::pair<int, int> numBlocksAndThreads = {nBlocks, nThreadsPerBlock};
cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(),
this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, this->scratchBufferSize_,
algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize, stream,
algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream,
nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
if (error != cudaSuccess) {
WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error));
@@ -305,7 +305,7 @@ std::shared_ptr<void> AllreduceRsAgPipeline::initAllreduceContext(std::shared_pt
void*, size_t, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
ctx->memorySemaphores = this->scratchSemaphores_;

View File

@@ -172,7 +172,7 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr<void
}
cudaError_t error =
allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(),
nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize,
nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize,
stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second);
if (error != cudaSuccess) {
if (error == cudaErrorInvalidValue) {
@@ -203,7 +203,7 @@ std::shared_ptr<void> AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt
void* output, size_t size, DataType) {
auto ctx = std::make_shared<AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->worldSize = comm->bootstrap()->getNranks();
ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain();
ctx->memorySemaphores = this->semaphores_;

View File

@@ -30,9 +30,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder {
void* scratchBuffer_;
size_t scratchBufferSize_;
const int nSegmentsForScratchBuffer_ = 2;
// Must be at least MAX_IPC_DOMAIN_NRANKS-1 so the adapter can launch one
// block per peer at MNNVL scale.
const int maxBlockNum_ = MAX_IPC_DOMAIN_NRANKS - 1;
const int maxBlockNum_ = 64;
std::vector<Connection> conns_;
std::vector<std::shared_ptr<MemoryDevice2DeviceSemaphore>> memorySemaphores_;
std::vector<RegisteredMemory> registeredMemories_;

View File

@@ -79,7 +79,7 @@ std::shared_ptr<DeviceHandle<BaseMemoryChannel>> setupBaseMemoryChannelDeviceHan
class AlgorithmCtx {
public:
int rank;
int workSize;
int worldSize;
int nRanksPerIpcDomain;
std::vector<RegisteredMemory> registeredMemories;