From c1071318c84f968bc292e2ef9b8296ba837d06af Mon Sep 17 00:00:00 2001 From: Caio Rocha <164253795+caiomcbr@users.noreply.github.com> Date: Tue, 19 May 2026 13:06:53 -0700 Subject: [PATCH 1/2] Include a static synchronization check in the DSL. (#806) --- python/mscclpp/language/channel.py | 6 ++++++ python/mscclpp/language/program.py | 32 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/python/mscclpp/language/channel.py b/python/mscclpp/language/channel.py index 23d76eda..de0f65c5 100644 --- a/python/mscclpp/language/channel.py +++ b/python/mscclpp/language/channel.py @@ -78,6 +78,7 @@ class MemoryChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = SignalOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def wait(self, tb: int, data_sync: SyncType = SyncType.both, relaxed: bool = False): """Wait for a signal through the memory channel. @@ -99,6 +100,7 @@ class MemoryChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = WaitOperation(tb_channel_ids, self.channel_type, data_sync, relaxed) get_program().add_operation(self.src_rank, tb, op) + get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type) def get(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int = None, tb_group: ThreadBlockGroup = None): """Retrieve data from remote memory to local memory. @@ -508,6 +510,7 @@ class PortChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = SignalOperation(tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def wait(self, tb: int, data_sync: SyncType = SyncType.both): """Wait for a signal through the port channel. 
@@ -527,6 +530,7 @@ class PortChannel: tb_channel_ids = get_program().setup_channel(tb, self) op = WaitOperation(tb_channel_ids, self.channel_type, data_sync) get_program().add_operation(self.src_rank, tb, op) + get_program().register_wait(self.src_rank, self.dst_rank, self.channel_type) def flush(self, tb: int, data_sync: SyncType = SyncType.both): """Flush pending operations through the port channel. @@ -636,6 +640,7 @@ class PortChannel: ) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): """Send data from local memory to remote memory with signal and flush. @@ -681,6 +686,7 @@ class PortChannel: ) get_program().add_operation(self.src_rank, tb, op) + get_program().register_signal(self.src_rank, self.dst_rank, self.channel_type) def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int): """Transfer data from local buffer to remote scratch buffer in packet format. diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index c29e9ab7..825a9d40 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -10,6 +10,7 @@ from mscclpp.language.rank import Semaphore from mscclpp.language.collectives import * from mscclpp.language.utils import AlgoSpec, ReplicationPolicy from typing import List +from collections import defaultdict import json @@ -112,6 +113,9 @@ class CollectiveProgram: self.loop_context = None + self._signal_counts = defaultdict(int) + self._wait_counts = defaultdict(int) + @classmethod def from_spec(cls, spec: AlgoSpec): """Initialize a new CollectiveProgram from an algorithm specification. 
@@ -206,7 +210,35 @@ class CollectiveProgram: else: self.gpus[rank].add_operation(tb, operation) + def register_signal(self, src_rank, dst_rank, channel_type): + """Record that `src_rank` issued a signal targeting `dst_rank` over `channel_type`.""" + self._signal_counts[(src_rank, dst_rank, channel_type)] += 1 + + def register_wait(self, src_rank, dst_rank, channel_type): + """Record that `src_rank` performed a wait for `dst_rank` over `channel_type`.""" + self._wait_counts[(src_rank, dst_rank, channel_type)] += 1 + + def validate_signal_wait_pairing(self): + """Validate that every signal issued by a rank is matched by a wait on the peer rank. + + For each (src_rank, dst_rank, channel_type) triple, the number of signals sent by + `src_rank` to `dst_rank` must equal the number of waits performed by `dst_rank` + for `src_rank` on a channel of the same type. Raises RuntimeError on mismatch. + """ + keys = set(self._signal_counts.keys()) | {(dst, src, t) for (src, dst, t) in self._wait_counts.keys()} + for src_rank, dst_rank, channel_type in keys: + signals = self._signal_counts.get((src_rank, dst_rank, channel_type), 0) + waits = self._wait_counts.get((dst_rank, src_rank, channel_type), 0) + if signals != waits: + raise RuntimeError( + f"Signal/Wait mismatch on {channel_type}: rank {src_rank} issues {signals} " + f"signal(s) to rank {dst_rank}, but rank {dst_rank} performs {waits} wait(s) " + f"for rank {src_rank}. Every signal must be matched by a corresponding wait " + f"on the peer rank over a channel of the same type." + ) + def post_process_operations(self): + self.validate_signal_wait_pairing() for gpu in self.gpus: if self.instr_fusion: gpu.optimize_operations() From 72621e72216c15ac9d636e00704dadc212e110aa Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 20 May 2026 09:29:55 -0700 Subject: [PATCH 2/2] add nBlocks check for allreduce_allpair_packet algo (#807) - Fix the correctness issue for allreduce_allpair_packet algo. 
Make sure the input buffer is not overwritten. Use the same tb for send/reduce/write-back. - Check that nBlocks/nthreads are valid for the packet algorithm. - Add more logs - Modify flag update logic so that it works for the case: nthreadPerNBlock < nflags --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../allreduce/allreduce_allpair_packet.cu | 60 +++++++++++-------- .../collectives/allreduce/allreduce_packet.cu | 5 ++ .../allreduce/allreduce_allpair_packet.hpp | 2 +- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 17bcfc33..faef5459 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -7,7 +7,7 @@ #include "allreduce/allreduce_allpair_packet.hpp" #include "allreduce/common.hpp" #include "collective_utils.hpp" -#include "debug.h" +#include "logger.hpp" namespace mscclpp { namespace collective { @@ -27,22 +27,30 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand size_t scratchBaseOffset = (flag % numScratchBuff) ? 
(scratchBufferSize / numScratchBuff) : 0; size_t channelScratchOffset = scratchBaseOffset; - const int nBlocksPerPeer = gridDim.x / nPeers; - const int localBlockIdx = blockIdx.x % nBlocksPerPeer; - const int tid = threadIdx.x + localBlockIdx * blockDim.x; - const int peerIdx = blockIdx.x / nBlocksPerPeer; - size_t srcOffset = channelDataOffset; + const int tid = threadIdx.x + blockIdx.x * blockDim.x; size_t scratchOffset = channelScratchOffset + rank * nelems * sizeof(LL8Packet); void* scratchBuff = (void*)((char*)scratch + channelScratchOffset); uint32_t* src = (uint32_t*)((char*)buff); uint32_t* dst = (uint32_t*)((char*)resultBuff); - // step 1: write data to each peer's scratch buffer - memoryChannels[peerIdx].putPackets(scratchOffset, srcOffset, nelems * sizeof(uint32_t), tid, - blockDim.x * nBlocksPerPeer, flag); + const int warpId = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int nWarpsPerBlock = blockDim.x / WARP_SIZE; + // Assign one warp in every block to each peer. Each peer warp sends the + // same block-owned stripe, so nBlocks only partitions data and no longer + // needs to be grouped by nPeers. + if (warpId < nPeers) { + memoryChannels[warpId].putPackets(scratchOffset, channelDataOffset, nelems * sizeof(uint32_t), + lane + blockIdx.x * WARP_SIZE, gridDim.x * WARP_SIZE, flag); + } + // Safe for in-place allreduce: all peer warps must finish reading src for + // this block's stripe before any warp writes reduced data back to dst/src. + __syncthreads(); - // step 2: Reduce Data - for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nelems; idx += blockDim.x * gridDim.x) { + // Split the same sent stream across all warps for reduction. warpId selects + // which strided subset to reduce while lane preserves coalesced packet reads. 
+ for (size_t idx = lane + blockIdx.x * WARP_SIZE + warpId * WARP_SIZE * gridDim.x; idx < nelems; + idx += nWarpsPerBlock * WARP_SIZE * gridDim.x) { uint32_t data = src[idx]; using AccRaw = std::conditional_t, uint32_t, mscclpp::VectorType>; @@ -59,14 +67,14 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand if (threadIdx.x == 0) { ((uint32_t*)flags)[blockIdx.x] = flag + 1; } - if (blockIdx.x == 0 && threadIdx.x >= gridDim.x && threadIdx.x < flagSize / sizeof(uint32_t)) { - ((uint32_t*)flags)[threadIdx.x] = flag + 1; + if (tid >= gridDim.x && tid < flagSize / sizeof(uint32_t)) { + ((uint32_t*)flags)[tid] = flag + 1; } } inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int worldSize) { if (inputSize < worldSize * sizeof(int)) { - return {worldSize - 1, 32}; + return {worldSize - 1, (worldSize - 1) * WARP_SIZE}; } return {(worldSize - 1) * 4, 512}; } @@ -80,11 +88,6 @@ struct AllpairAdapter { int nThreadsPerBlock = 0) { using ChannelType = DeviceHandle; const size_t nelems = inputSize / sizeof(T); - // Round nBlocks to multiple of nPeers so every block maps to a valid peer. - const int nPeers = worldSize - 1; - if (nPeers > 0) { - nBlocks = (nBlocks / nPeers) * nPeers; - } allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize); @@ -110,9 +113,17 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptrworkSize); } - // nBlocks must be at least nPeers for allpair — each block maps to one peer. 
+ if (blockAndThreadNum.first > maxBlockNum_) { + WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ", + maxBlockNum_, "."); + return CommResult::CommInvalidArgument; + } const int nPeers = algoCtx->nRanksPerNode - 1; - if (nPeers > 0 && blockAndThreadNum.first < nPeers) { + // The kernel maps peer sends by warpId, so every peer needs a full warp. + if (blockAndThreadNum.second % WARP_SIZE != 0 || blockAndThreadNum.second / WARP_SIZE < nPeers) { + WARN(ALGO, + "Allpair packet requires at least one full warp per peer, but got nThreadsPerBlock=", blockAndThreadNum.second, + " and nPeers=", nPeers, "."); return CommResult::CommInvalidArgument; } size_t sendBytes; @@ -122,7 +133,8 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr(op, dtype, accumDtype); if (!allreduce) { - WARN("Unsupported operation or data type for allreduce: op=%d, dtype=%d", op, static_cast(dtype)); + WARN(ALGO, "Unsupported operation or data type for allreduce: op=", static_cast(op), + ", dtype=", static_cast(dtype)); return CommResult::CommInvalidArgument; } cudaError_t error = @@ -131,7 +143,7 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptrworkSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { - WARN("AllreducePacket failed with error: %s", cudaGetErrorString(error)); + WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } return CommResult::CommSuccess; @@ -189,4 +201,4 @@ std::shared_ptr AllreduceAllpairPacket::build() { }); } } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index 6199f192..3c75a746 100644 --- 
a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -235,6 +235,11 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->workSize, ctx->nRanksPerNode, dtype); } + if (blockAndThreadNum.first > maxBlockNum_) { + WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ", + maxBlockNum_, "."); + return CommResult::CommInvalidArgument; + } size_t sendBytes; CUdeviceptr sendBasePtr; diff --git a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp index 362308b2..64f5ec54 100644 --- a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp @@ -29,7 +29,7 @@ class AllreduceAllpairPacket : public AlgorithmBuilder { void* scratchBuffer_; size_t scratchBufferSize_; const int nSegmentsForScratchBuffer_ = 2; - const int maxBlockNum_ = 28; + const int maxBlockNum_ = 64; std::vector conns_; std::vector> memorySemaphores_; std::vector registeredMemories_;