diff --git a/examples/customized-collective-algorithm/customized_allgather.cu b/examples/customized-collective-algorithm/customized_allgather.cu index 02df3685..13802f80 100644 --- a/examples/customized-collective-algorithm/customized_allgather.cu +++ b/examples/customized-collective-algorithm/customized_allgather.cu @@ -79,7 +79,7 @@ __global__ void __launch_bounds__(1024) struct Context { int rank; - int workSize; + int worldSize; int nRanksPerNode; std::vector registeredMemories; @@ -140,7 +140,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { size_t inputSize, cudaStream_t stream) { auto algoCtx = std::static_pointer_cast(ctx); int rank = algoCtx->rank; - int worldSize = algoCtx->workSize; + int worldSize = algoCtx->worldSize; int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE; allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputSize); @@ -154,16 +154,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { void* output, size_t inputSize, mscclpp::DataType dtype) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // register memories mscclpp::RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc); mscclpp::RegisteredMemory outputBufRegMem = - comm->registerMemory(output, inputSize * ctx->workSize, mscclpp::Transport::CudaIpc); + comm->registerMemory(output, inputSize * ctx->worldSize, mscclpp::Transport::CudaIpc); std::vector> remoteRegMemories; - for (int i = 0; i < ctx->workSize; i++) { + for (int i = 0; i < ctx->worldSize; i++) { if (i == ctx->rank) continue; comm->sendMemory(outputBufRegMem, i, 0); remoteRegMemories.push_back(comm->recvMemory(i, 0)); diff --git a/examples/torch-integration/customized_allgather.cu b/examples/torch-integration/customized_allgather.cu index 907b3ada..5ba2935f 100644 --- a/examples/torch-integration/customized_allgather.cu +++ b/examples/torch-integration/customized_allgather.cu @@ -47,7 +47,7 @@ __global__ void __launch_bounds__(1024) struct Context { int rank; - int workSize; + int worldSize; int nRanksPerNode; std::vector registeredMemories; @@ -108,7 +108,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { cudaStream_t stream) { auto algoCtx = std::static_pointer_cast(ctx); int rank = algoCtx->rank; - int worldSize = algoCtx->workSize; + int worldSize = algoCtx->worldSize; int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE; allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputBytes); @@ -122,16 +122,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { void* output, size_t inputBytes, mscclpp::DataType dtype) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // register memories mscclpp::RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputBytes, mscclpp::Transport::CudaIpc); mscclpp::RegisteredMemory outputBufRegMem = - comm->registerMemory(output, inputBytes * ctx->workSize, mscclpp::Transport::CudaIpc); + comm->registerMemory(output, inputBytes * ctx->worldSize, mscclpp::Transport::CudaIpc); std::vector> remoteRegMemories; - for (int i = 0; i < ctx->workSize; i++) { + for (int i = 0; i < ctx->worldSize; i++) { if (i == ctx->rank) continue; comm->sendMemory(outputBufRegMem, i, 0); remoteRegMemories.push_back(comm->recvMemory(i, 0)); diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index 23d76eda..de0f65c5 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -78,6 +78,7 @@ class MemoryChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False): """Wait for a signal through the memory channel. @@ -99,6 +100,7 @@ class MemoryChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) + get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type) def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Retrieve data from remote memory to local memory. @@ -508,6 +510,7 @@ class PortChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = SignalOperation(tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def wait(self, tb: int, data_sync: SyncType = SyncType.both): """Wait for a signal through the port channel. @@ -527,6 +530,7 @@ class PortChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = WaitOperation(tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) + get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type) def flush(self, tb: int, data_sync: SyncType = SyncType.both): """Flush pending operations through the port channel. @@ -636,6 +640,7 @@ class PortChannel: ) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): """Send data from local memory to remote memory with signal and flush. @@ -681,6 +686,7 @@ class PortChannel: ) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): """Transfer data from local buffer to remote scratch buffer in packet format. diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index c29e9ab7..825a9d40 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -10,6 +10,7 @@ from mscclpp.language.rank import Semaphore from mscclpp.language.collectives import * from mscclpp.language.utils import AlgoSpec, ReplicationPolicy from typing import List +from collections import defaultdict import json @@ -112,6 +113,9 @@ class CollectiveProgram: self.loop_context = None + self._signal_counts = defaultdict(int) + self._wait_counts = defaultdict(int) + @classmethod def from_spec(cls, spec: AlgoSpec): """Initialize a new CollectiveProgram from an algorithm specification. @@ -206,7 +210,35 @@ class CollectiveProgram: else: self.gpus[rank].add_operation(tb, operation) + def register_signal(self, src_rank, dst_rank, channel_type): + """Record that `src_rank` issued a signal targeting `dst_rank` over `channel_type`.""" + self._signal_counts[(src_rank, dst_rank, channel_type)] += 1 + + def register_wait(self, src_rank, dst_rank, channel_type): + """Record that `src_rank` performed a wait for `dst_rank` over `channel_type`.""" + self._wait_counts[(src_rank, dst_rank, channel_type)] += 1 + + def validate_signal_wait_pairing(self): + """Validate that every signal issued by a rank is matched by a wait on the peer rank. + + For each (src_rank, dst_rank, channel_type) triple, the number of signals sent by + `src_rank` to `dst_rank` must equal the number of waits performed by `dst_rank` + for `src_rank` on a channel of the same type. Raises RuntimeError on mismatch. + """ + keys = set(self._signal_counts.keys()) | {(dst, src, t) for (src, dst, t) in self._wait_counts.keys()} + for src_rank, dst_rank, channel_type in keys: + signals = self._signal_counts.get((src_rank, dst_rank, channel_type), 0) + waits = self._wait_counts.get((dst_rank, src_rank, channel_type), 0) + if signals != waits: + raise RuntimeError( + f"Signal/Wait mismatch on {channel_type}: rank {src_rank} issues {signals} " + f"signal(s) to rank {dst_rank}, but rank {dst_rank} performs {waits} wait(s) " + f"for rank {src_rank}. Every signal must be matched by a corresponding wait " + f"on the peer rank over a channel of the same type." + ) + def post_process_operations(self): + self.validate_signal_wait_pairing() for gpu in self.gpus: if self.instr_fusion: gpu.optimize_operations() diff --git a/src/ext/collectives/allgather/allgather_fullmesh.cu b/src/ext/collectives/allgather/allgather_fullmesh.cu index 570a2d61..d9d52630 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh.cu @@ -127,11 +127,11 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr ct if ((char*)input == (char*)output + rank * inputSize) { allgatherFullmesh<<>>( (void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank, - ctx->nRanksPerIpcDomain, ctx->workSize, nElem); + ctx->nRanksPerIpcDomain, ctx->worldSize, nElem); } else { allgatherFullmesh<<>>( (void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank, - ctx->nRanksPerIpcDomain, ctx->workSize, nElem); + ctx->nRanksPerIpcDomain, ctx->worldSize, nElem); } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { @@ -147,7 +147,7 @@ std::shared_ptr AllgatherFullmesh::initAllgatherContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index f344824f..2217edc7 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -139,11 +139,11 @@ CommResult AllgatherFullmesh2::allgatherKernelFunc(const std::shared_ptr c size_t channelOutOffset = *static_cast(ctx->extras["channel_out_offset"].get()); if ((char*)input == (char*)output + rank * inputSize) { allgatherFullmesh2<<>>( - (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize, + (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize, ctx->nRanksPerIpcDomain, nElem); } else { allgatherFullmesh2<<>>( - (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize, + (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize, ctx->nRanksPerIpcDomain, nElem); } cudaError_t err = cudaGetLastError(); @@ -158,7 +158,7 @@ std::shared_ptr AllgatherFullmesh2::initAllgatherContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 47c4f61d..3b2375a6 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -7,7 +7,7 @@ #include "allreduce/allreduce_allpair_packet.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" -#include "debug.h" +#include "logger.hpp" namespace mscclpp { namespace collective { @@ -24,22 +24,30 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand size_t scratchBaseOffset = (flag % numScratchBuff) ? (scratchBufferSize / numScratchBuff) : 0; size_t channelScratchOffset = scratchBaseOffset; - const int nBlocksPerPeer = gridDim.x / nPeers; - const int localBlockIdx = blockIdx.x % nBlocksPerPeer; - const int tid = threadIdx.x + localBlockIdx * blockDim.x; - const int peerIdx = blockIdx.x / nBlocksPerPeer; - size_t srcOffset = channelDataOffset; + const int tid = threadIdx.x + blockIdx.x * blockDim.x; size_t scratchOffset = channelScratchOffset + rank * nelems * sizeof(LL8Packet); void* scratchBuff = (void*)((char*)scratch + channelScratchOffset); uint32_t* src = (uint32_t*)((char*)buff); uint32_t* dst = (uint32_t*)((char*)resultBuff); - // step 1: write data to each peer's scratch buffer - memoryChannels[peerIdx].putPackets(scratchOffset, srcOffset, nelems * sizeof(uint32_t), tid, - blockDim.x * nBlocksPerPeer, flag); + const int warpId = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int nWarpsPerBlock = blockDim.x / WARP_SIZE; + // Assign one warp in every block to each peer. Each peer warp sends the + // same block-owned stripe, so nBlocks only partitions data and no longer + // needs to be grouped by nPeers. + if (warpId < nPeers) { + memoryChannels[warpId].putPackets(scratchOffset, channelDataOffset, nelems * sizeof(uint32_t), + lane + blockIdx.x * WARP_SIZE, gridDim.x * WARP_SIZE, flag); + } + // Safe for in-place allreduce: all peer warps must finish reading src for + // this block's stripe before any warp writes reduced data back to dst/src. + __syncthreads(); - // step 2: Reduce Data - for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) { + // Split the same sent stream across all warps for reduction. warpId selects + // which strided subset to reduce while lane preserves coalesced packet reads. + for (size_t idx = lane + blockIdx.x * WARP_SIZE + warpId * WARP_SIZE * gridDim.x; idx < nelems; + idx += nWarpsPerBlock * WARP_SIZE * gridDim.x) { uint32_t data = src[idx]; using AccRaw = std::conditional_t, uint32_t, mscclpp::VectorType>; @@ -56,16 +64,16 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand if (threadIdx.x == 0) { ((uint32_t*)flags)[blockIdx.x] = flag + 1; } - if (blockIdx.x == 0 && threadIdx.x >= gridDim.x && threadIdx.x < flagSize / sizeof(uint32_t)) { - ((uint32_t*)flags)[threadIdx.x] = flag + 1; + if (tid >= gridDim.x && tid < flagSize / sizeof(uint32_t)) { + ((uint32_t*)flags)[tid] = flag + 1; } } -inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int worldSize) { - if (inputSize < worldSize * sizeof(int)) { - return {worldSize - 1, 32}; +inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int nRanksPerIpcDomain) { + if (inputSize < nRanksPerIpcDomain * sizeof(int)) { + return {nRanksPerIpcDomain - 1, (nRanksPerIpcDomain - 1) * WARP_SIZE}; } - return {(worldSize - 1) * 4, 512}; + return {(nRanksPerIpcDomain - 1) * 4, 512}; } template @@ -77,9 +85,6 @@ struct AllpairAdapter { int nThreadsPerBlock = 0) { using ChannelType = DeviceHandle; const size_t nelems = inputSize / sizeof(T); - // Round nBlocks to multiple of nPeers so every block maps to a valid peer. - const int nPeers = nRanksPerIpcDomain - 1; - nBlocks = (nBlocks / nPeers) * nPeers; allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, nRanksPerIpcDomain, worldSize, nelems, numScratchBuff, flags, flagSize); @@ -101,18 +106,27 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr&, DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); - if (algoCtx->workSize != algoCtx->nRanksPerIpcDomain) { - WARN("AllreduceAllpairPacket requires workSize to match nRanksPerIpcDomain, got workSize=%d, nRanksPerIpcDomain=%d", - algoCtx->workSize, algoCtx->nRanksPerIpcDomain); + if (algoCtx->worldSize != algoCtx->nRanksPerIpcDomain) { + WARN(ALGO, + "AllreduceAllpairPacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", algoCtx->worldSize, + ", nRanksPerIpcDomain=", algoCtx->nRanksPerIpcDomain); return CommResult::CommInvalidArgument; } std::pair blockAndThreadNum{nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->nRanksPerIpcDomain); } - // nBlocks must be at least nPeers for allpair — each block maps to one peer. + if (blockAndThreadNum.first > maxBlockNum_) { + WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ", + maxBlockNum_, "."); + return CommResult::CommInvalidArgument; + } const int nPeers = algoCtx->nRanksPerIpcDomain - 1; - if (blockAndThreadNum.first < nPeers) { + // The kernel maps peer sends by warpId, so every peer needs a full warp. + if (blockAndThreadNum.second % WARP_SIZE != 0 || blockAndThreadNum.second / WARP_SIZE < nPeers) { + WARN(ALGO, + "Allpair packet requires at least one full warp per peer, but got nThreadsPerBlock=", blockAndThreadNum.second, + " and nPeers=", nPeers, "."); return CommResult::CommInvalidArgument; } size_t sendBytes; @@ -122,16 +136,17 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr(op, dtype, accumDtype); if (!allreduce) { - WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast(dtype)); + WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), + ", dtype=", static_cast(dtype)); return CommResult::CommInvalidArgument; } cudaError_t error = allreduce(input, this->scratchBuffer_, output, algoCtx->memoryChannelDeviceHandles.get(), nullptr, nullptr, nullptr, channelInOffset, 0, this->scratchBufferSize_, algoCtx->rank, algoCtx->nRanksPerIpcDomain, - algoCtx->workSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, + algoCtx->worldSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { - WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error)); + WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } return CommResult::CommSuccess; @@ -142,7 +157,7 @@ std::shared_ptr AllreduceAllpairPacket::initAllreduceContext(std::shared_p auto ctx = std::make_shared(); const int nChannelsPerConnection = maxBlockNum_; ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu index 2790295e..f547ab4f 100644 --- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu +++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu @@ -223,7 +223,7 @@ CommResult AllreduceFullmesh::allreduceKernelFunc( } cudaError_t error = allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(), ctx->memoryChannelDeviceHandles.get(), - nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, + nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { WARN("AllreduceAllconnect failed with error: %s", cudaGetErrorString(error)); @@ -249,7 +249,7 @@ std::shared_ptr AllreduceFullmesh::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index 04c7f8c9..1edbc011 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -205,7 +205,7 @@ CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc( } cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr, ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_, - ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream, nullptr, 0, 0, + ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN("AllreduceNvlsBlockPipeline failed with error: %s", cudaGetErrorString(error)); @@ -222,7 +222,7 @@ std::shared_ptr AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar void*, size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index 1918eef1..98d9e1a3 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -93,7 +93,7 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels @@ -123,7 +123,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr } cudaError_t error = allreduce(input, this->scratchBuffer_, output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(), nullptr, - 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream, + 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error)); diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index d5bbb2e7..d4492ed5 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -169,7 +169,7 @@ CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc( } cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr, ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_, - ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, inputSize, stream, nullptr, 0, 0, + ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN("AllreduceNvlsWarpPipeline failed with error: %s", cudaGetErrorString(error)); @@ -186,7 +186,7 @@ std::shared_ptr AllreduceNvlsWarpPipeline::initAllreduceContext(std::share void*, size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 481e8ad8..f76dd079 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -149,17 +149,17 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr ctx_vo // the number of GPUs. Empirically, 32 blocks works well for 4 GPUs and 16 for 8 GPUs, which // follows the formula 128 / nGPUs, clamped to [1, MAX_NBLOCKS]. if (computeCapabilityMajor_ == 10) { - numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->workSize, MAX_NBLOCKS)); + numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->worldSize, MAX_NBLOCKS)); } } if (numBlocksAndThreads.first > MAX_NBLOCKS) { WARN("Number of blocks exceeds maximum supported value of %d", MAX_NBLOCKS); return CommResult::CommInvalidArgument; } - cudaError_t error = - allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr, nvlsChannels, - nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, - ctx->workSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); + cudaError_t error = allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr, + nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank, + ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, + numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { if (error == cudaErrorNotSupported) { WARN("AllreduceNvls does not support the requested data type."); @@ -185,7 +185,7 @@ std::shared_ptr AllreduceNvls::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); size_t sendBytes, recvBytes; diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index d20625ee..8591c983 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -230,20 +230,25 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ const std::unordered_map&, DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); - if (ctx->workSize != ctx->nRanksPerIpcDomain) { - WARN(ALGO, "AllreducePacket requires workSize to match nRanksPerIpcDomain, got workSize=", ctx->workSize, + if (ctx->worldSize != ctx->nRanksPerIpcDomain) { + WARN(ALGO, "AllreducePacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", ctx->worldSize, ", nRanksPerIpcDomain=", ctx->nRanksPerIpcDomain); return CommResult::CommInvalidArgument; } std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { - blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerIpcDomain, ctx->workSize, dtype); + blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerIpcDomain, ctx->worldSize, dtype); } else { const int nPeers = ctx->nRanksPerIpcDomain - 1; if (blockAndThreadNum.first < nPeers) { return CommResult::CommInvalidArgument; } } + if (blockAndThreadNum.first > maxBlockNum_) { + WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ", + maxBlockNum_, "."); + return CommResult::CommInvalidArgument; + } size_t sendBytes; CUdeviceptr sendBasePtr; @@ -258,7 +263,7 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ } cudaError_t error = allreduce(input, this->scratchBuffer_, output, ctx->memoryChannelDeviceHandles.get(), nullptr, nullptr, nullptr, - channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->workSize, + channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { @@ -273,7 +278,7 @@ std::shared_ptr AllreducePacket::initAllreduceContext(std::shared_ptr(); const int nChannelsPerConnection = maxBlockNum_; ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index 6fffc4da..1f5d3e5d 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -185,7 +185,7 @@ CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr ctx, c } cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(), this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank, - algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize, stream, nullptr, 0, 0, + algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error)); @@ -202,7 +202,7 @@ std::shared_ptr AllreduceRsAg::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->scratchSemaphores_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu index e9d543ea..4b243444 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu @@ -288,7 +288,7 @@ CommResult AllreduceRsAgPipeline::allreduceKernelFunc( std::pair numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(), this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, this->scratchBufferSize_, - algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize, stream, + algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error)); @@ -305,7 +305,7 @@ std::shared_ptr AllreduceRsAgPipeline::initAllreduceContext(std::shared_pt void*, size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->scratchSemaphores_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index 753ad799..877a722a 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -172,7 +172,7 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptrbaseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(), - nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->workSize, inputSize, + nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { if (error == cudaErrorInvalidValue) { @@ -203,7 +203,7 @@ std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt void* output, size_t size, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->semaphores_; diff --git a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp index d2ea7259..bba82ee5 100644 --- a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp @@ -30,9 +30,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder { void* scratchBuffer_; size_t scratchBufferSize_; const int nSegmentsForScratchBuffer_ = 2; - // Must be at least MAX_IPC_DOMAIN_NRANKS-1 so the adapter can launch one - // block per peer at MNNVL scale. - const int maxBlockNum_ = MAX_IPC_DOMAIN_NRANKS - 1; + const int maxBlockNum_ = 64; std::vector conns_; std::vector> memorySemaphores_; std::vector registeredMemories_; diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index 95ce7f5a..be18477a 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -79,7 +79,7 @@ std::shared_ptr> setupBaseMemoryChannelDeviceHan class AlgorithmCtx { public: int rank; - int workSize; + int worldSize; int nRanksPerIpcDomain; std::vector registeredMemories;