From f803eff8b92275dabf71093eb33cb00ca670d099 Mon Sep 17 00:00:00 2001
From: Qinghua Zhou
Date: Tue, 24 Feb 2026 04:05:01 +0000
Subject: [PATCH] Use multiple thread blocks; Add peer-parallel kernels

---
 python/mscclpp/ext/alltoallv_single.py       |  63 ++++---
 .../alltoallv/alltoallv_fullmesh.cu          |  43 +++--
 .../include/alltoallv/alltoallv_kernel.hpp   | 171 ++++++++++++++++++
 test/mscclpp-test/alltoallv_test.cu          |  48 ++++-
 4 files changed, 281 insertions(+), 44 deletions(-)

diff --git a/python/mscclpp/ext/alltoallv_single.py b/python/mscclpp/ext/alltoallv_single.py
index 554eeab1..acf62675 100644
--- a/python/mscclpp/ext/alltoallv_single.py
+++ b/python/mscclpp/ext/alltoallv_single.py
@@ -149,6 +149,22 @@ class MscclppAlltoAllV:
         self._d_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
         self._d_remote_recv_displs = torch.zeros(self._world_size, dtype=torch.int64, device='cuda')
 
+        # Cache for split sizes to avoid redundant bootstrap exchanges and GPU copies.
+        # Key: (tuple(send_counts_bytes), tuple(recv_counts_bytes))
+        self._cached_splits_key = None
+        self._cached_input_size = 0
+        self._cached_output_size = 0
+        self._cached_total_output_elems = 0
+        self._cached_dtype = None
+        # Pre-built extras dict (GPU pointers don't change)
+        self._extras = {
+            "sendCounts": self._d_send_counts.data_ptr(),
+            "sendDispls": self._d_send_displs.data_ptr(),
+            "recvCounts": self._d_recv_counts.data_ptr(),
+            "recvDispls": self._d_recv_displs.data_ptr(),
+            "remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
+        }
+
     @property
     def rank(self) -> int:
         return self._rank
@@ -219,35 +235,32 @@ class MscclppAlltoAllV:
         send_displs_bytes = [d * elem_size for d in send_displs]
         recv_counts_bytes = [s * elem_size for s in output_split_sizes]
         recv_displs_bytes = [d * elem_size for d in recv_displs]
-
-        # Copy to GPU
-        self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
-        self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
-        self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
-        self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
-
-        # Exchange recv displacements with all peers so each rank knows where to
-        # write in the remote output buffer. remoteRecvDispls[peer] = peer's
-        # recvDispls[rank], i.e. the offset in peer's output where our data goes.
-        remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
-        self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
+
+        # Fast path: skip GPU copies + bootstrap exchange if split sizes unchanged
+        splits_key = (tuple(send_counts_bytes), tuple(recv_counts_bytes))
+        if splits_key != self._cached_splits_key:
+            # Copy counts/displacements to GPU
+            self._d_send_counts.copy_(torch.tensor(send_counts_bytes, dtype=torch.int64))
+            self._d_send_displs.copy_(torch.tensor(send_displs_bytes, dtype=torch.int64))
+            self._d_recv_counts.copy_(torch.tensor(recv_counts_bytes, dtype=torch.int64))
+            self._d_recv_displs.copy_(torch.tensor(recv_displs_bytes, dtype=torch.int64))
+
+            # Exchange recv displacements with peers via bootstrap
+            remote_recv_displs = self._exchange_recv_displs(recv_displs_bytes)
+            self._d_remote_recv_displs.copy_(torch.tensor(remote_recv_displs, dtype=torch.int64))
+
+            # Cache for subsequent calls
+            self._cached_splits_key = splits_key
+            self._cached_input_size = sum(send_counts_bytes)
+            self._cached_output_size = sum(recv_counts_bytes)
 
         # Get stream
         if stream is None:
             stream = torch.cuda.current_stream()
         cuda_stream = stream.cuda_stream
-
-        # Build extras dict with GPU pointers
-        extras = {
-            "sendCounts": self._d_send_counts.data_ptr(),
-            "sendDispls": self._d_send_displs.data_ptr(),
-            "recvCounts": self._d_recv_counts.data_ptr(),
-            "recvDispls": self._d_recv_displs.data_ptr(),
-            "remoteRecvDispls": self._d_remote_recv_displs.data_ptr(),
-        }
-
-        input_size = sum(send_counts_bytes)
-        output_size = sum(recv_counts_bytes)
+
+        input_size = self._cached_input_size
+        output_size = self._cached_output_size
 
         # Execute the optimized kernel
         result = self._algo.execute(
@@ -262,7 +275,7 @@
             None,  # executor (not needed for native algos)
             0,  # nblocks (auto)
             0,  # nthreads_per_block (auto)
-            extras,
+            self._extras,
         )
 
         if result != 0:
diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
index 2e6ffbe8..4b129fd2 100644
--- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
+++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu
@@ -31,6 +31,7 @@ struct AllToAllVContext {
   std::vector<mscclpp::MemoryChannel> memoryChannels;
   std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores;
   std::shared_ptr<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
+  std::shared_ptr<mscclpp::DeviceSyncer> deviceSyncer;  // GPU-allocated, for multi-block grid sync
 };
 
 AlltoallvFullmesh::~AlltoallvFullmesh() = default;
@@ -102,36 +103,39 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc(
   // Use maximum threads (1024) for best bandwidth utilization
   const int threadsPerBlock = (nThreadsPerBlock > 0 && nThreadsPerBlock <= 1024) ? nThreadsPerBlock : 1024;
 
-  // Size-adaptive algorithm selection based on message size and world size:
-  // - Small messages (<1MB avg): use basic kernel (lower latency)
-  // - Large messages (>=1MB avg) with small world (<=16): use pipelined kernel
-  // - Large messages (>=1MB avg) with large world (>16): use ring kernel (avoids congestion)
+  // Peer-parallel algorithm: blocks assigned round-robin to peers so ALL
+  // NVLink connections are active simultaneously. Critical for 4+ GPU systems.
+  //
+  // Small messages (<1MB avg): nPeers blocks (1 per peer, no barrier)
+  // Large messages (>=1MB avg): nPeers * blocksPerPeer (barrier-based)
   constexpr size_t SIZE_THRESHOLD = 1 << 20;  // 1MB
-  constexpr int WORLD_SIZE_THRESHOLD = 16;
   size_t avgMsgSize = inputSize / worldSize;
+  int nPeers = worldSize - 1;
+  if (nPeers < 1) nPeers = 1;
 
   if (avgMsgSize < SIZE_THRESHOLD) {
-    // Small messages: use basic kernel for lower latency
-    alltoallvKernel<<<1, threadsPerBlock, 0, stream>>>(
-        algoCtx->memoryChannelDeviceHandles.get(),
-        rank, worldSize,
-        input, output,
-        d_sendCounts, d_sendDispls,
-        d_recvCounts, d_recvDispls,
-        d_remoteRecvDispls);
-  } else if (worldSize > WORLD_SIZE_THRESHOLD) {
-    // Large messages + large world: use ring kernel to avoid congestion
-    alltoallvRingKernel<<<1, threadsPerBlock, 0, stream>>>(
+    // Small messages: 1 block per peer, parallel signal/wait, no barrier
+    int numBlocks = nPeers;
+    alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
         algoCtx->memoryChannelDeviceHandles.get(),
+        algoCtx->deviceSyncer.get(),
         rank, worldSize,
         input, output,
         d_sendCounts, d_sendDispls,
         d_recvCounts, d_recvDispls,
         d_remoteRecvDispls);
   } else {
-    // Large messages + small world: use pipelined chunked kernel
-    alltoallvPipelinedKernel<<<1, threadsPerBlock, 0, stream>>>(
+    // Large messages: multiple blocks per peer for maximum put bandwidth.
+    // Cap total blocks to avoid excessive barrier overhead.
+    int blocksPerPeer = (nBlocks > 0 && nBlocks <= 128)
+        ? ((nBlocks + nPeers - 1) / nPeers)  // user-specified total → per-peer
+        : ALLTOALLV_DEFAULT_BLOCKS_PER_PEER;
+    int numBlocks = nPeers * blocksPerPeer;
+    if (numBlocks > 128) numBlocks = (128 / nPeers) * nPeers;  // keep multiple of nPeers
+    if (numBlocks < nPeers) numBlocks = nPeers;
+    alltoallvPeerParallelKernel<<<numBlocks, threadsPerBlock, 0, stream>>>(
         algoCtx->memoryChannelDeviceHandles.get(),
+        algoCtx->deviceSyncer.get(),
         rank, worldSize,
         input, output,
         d_sendCounts, d_sendDispls,
@@ -176,6 +180,9 @@ std::shared_ptr<AllToAllVContext> AlltoallvFullmesh::initAlltoallvContext(
   // Setup device handles
   ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);
 
+  // Allocate GPU DeviceSyncer for multi-block grid-wide barrier (zero-initialized)
+  ctx->deviceSyncer = mscclpp::detail::gpuCallocShared<mscclpp::DeviceSyncer>();
+
   // Keep registered memory references to prevent deallocation
   ctx->registeredMemories = std::move(remoteOutputMemories);
   ctx->registeredMemories.push_back(inputBufRegMem);
diff --git a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
index b690d204..d0d5c88b 100644
--- a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
+++ b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp
@@ -20,6 +20,177 @@ namespace collective {
 // Large enough to amortize overhead, small enough for good memory patterns
 constexpr size_t ALLTOALLV_CHUNK_SIZE = 1 << 20;
 
+// Default number of blocks for multi-block kernels.
+// Tuned for H100 (132 SMs). Enough to saturate NVLink bandwidth without
+// excessive DeviceSyncer overhead.
+constexpr int ALLTOALLV_DEFAULT_NBLOCKS = 24;
+
+// Default blocks per peer for the peer-parallel kernel.
+// Controls how many thread blocks cooperate on each peer's data transfer.
+constexpr int ALLTOALLV_DEFAULT_BLOCKS_PER_PEER = 16;
+
+/**
+ * Peer-parallel AllToAllV kernel for maximum throughput with multiple GPUs.
+ *
+ * Unlike the sequential multi-block kernel that processes one peer at a time,
+ * this kernel assigns blocks to peers round-robin so ALL NVLink connections
+ * are active simultaneously. This is critical for 4+ GPU systems where the
+ * per-peer bandwidth is a fraction of aggregate NVLink bandwidth.
+ *
+ * Block assignment: block i handles peer (i % nPeers). Blocks assigned to the
+ * same peer cooperate using local thread IDs within the peer's block group.
+ *
+ * Signal/wait is also parallelized: each peer's primary block (localBlockIdx==0)
+ * independently signals and waits, overlapping wait latencies across peers:
+ * total wait = O(max) instead of O(sum).
+ *
+ * For small messages: launch with nPeers blocks (1 per peer, __syncthreads only)
+ * For large messages: launch with nPeers*K blocks (K per peer, DeviceSyncer barrier)
+ *
+ * Launch config: <<<numBlocks, threadsPerBlock>>> where numBlocks >= nPeers
+ */
+__global__ void __launch_bounds__(1024)
+    alltoallvPeerParallelKernel(DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
+                                DeviceSyncer* syncer,
+                                int rank,
+                                int worldSize,
+                                const void* sendBuff,
+                                void* recvBuff,
+                                const size_t* sendCounts,
+                                const size_t* sendDispls,
+                                const size_t* recvCounts,
+                                const size_t* recvDispls,
+                                const size_t* remoteRecvDispls) {
+  const int nPeers = worldSize - 1;
+
+  // Handle trivial case (single rank, no peers)
+  if (nPeers == 0) {
+    const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
+    const int nThreads = blockDim.x * gridDim.x;
+    if (sendCounts[rank] > 0) {
+      mscclpp::copy((char*)recvBuff + recvDispls[rank],
+                    (void*)((const char*)sendBuff + sendDispls[rank]),
+                    sendCounts[rank], gtid, nThreads);
+    }
+    return;
+  }
+
+  // Phase 1: Local copy — all blocks cooperate using global thread IDs
+  const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
+  const int nThreads = blockDim.x * gridDim.x;
+  if (sendCounts[rank] > 0) {
+    mscclpp::copy((char*)recvBuff + recvDispls[rank],
+                  (void*)((const char*)sendBuff + sendDispls[rank]),
+                  sendCounts[rank], gtid, nThreads);
+  }
+
+  // Phase 2: Per-peer remote puts — blocks assigned round-robin to peers.
+  // Block i handles peer (i % nPeers). Multiple blocks for the same peer
+  // cooperate using local thread IDs within the peer's block group.
+  const int myPeerIdx = blockIdx.x % nPeers;
+  const int localBlockIdx = blockIdx.x / nPeers;
+  const int nBlocksForMyPeer = ((int)gridDim.x - myPeerIdx + nPeers - 1) / nPeers;
+
+  const int localTid = threadIdx.x + localBlockIdx * blockDim.x;
+  const int nLocalThreads = nBlocksForMyPeer * blockDim.x;
+
+  const int peer = myPeerIdx < rank ? myPeerIdx : myPeerIdx + 1;
+
+  if (sendCounts[peer] > 0) {
+    memoryChannels[myPeerIdx].put(
+        remoteRecvDispls[peer],  // dst offset in peer's buffer
+        sendDispls[peer],        // src offset in our buffer
+        sendCounts[peer],        // size
+        localTid,                // thread id within peer's block group
+        nLocalThreads            // total threads for this peer
+    );
+  }
+
+  // Phase 3: Synchronization
+  // - Multiple blocks per peer (gridDim.x > nPeers): grid-wide barrier to ensure
+  //   all blocks' put contributions complete before any signaling.
+  // - Exactly one block per peer: __syncthreads() suffices (no cross-block deps).
+  if ((int)gridDim.x > nPeers) {
+    syncer->sync(gridDim.x);
+  } else {
+    __syncthreads();
+  }
+
+  // Phase 4: Signal and wait — parallelized across peers.
+  // Each peer's primary block (localBlockIdx==0, thread 0) independently
+  // signals and waits. Wait latencies overlap: O(max) instead of O(sum).
+  if (threadIdx.x == 0 && localBlockIdx == 0) {
+    memoryChannels[myPeerIdx].signal();
+    if (recvCounts[peer] > 0) {
+      memoryChannels[myPeerIdx].wait();
+    }
+  }
+}
+
+/**
+ * Legacy multi-block AllToAllV kernel (sequential peers).
+ *
+ * All thread blocks cooperate on each peer's data transfer using global thread IDs.
+ * Peers are processed sequentially. Kept for comparison; prefer alltoallvPeerParallelKernel.
+ *
+ * Launch config: <<<numBlocks, threadsPerBlock>>>
+ */
+__global__ void __launch_bounds__(1024)
+    alltoallvMultiBlockKernel(DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
+                              DeviceSyncer* syncer,
+                              int rank,
+                              int worldSize,
+                              const void* sendBuff,
+                              void* recvBuff,
+                              const size_t* sendCounts,
+                              const size_t* sendDispls,
+                              const size_t* recvCounts,
+                              const size_t* recvDispls,
+                              const size_t* remoteRecvDispls) {
+  const int gtid = threadIdx.x + blockIdx.x * blockDim.x;
+  const int nThreads = blockDim.x * gridDim.x;
+  const int nPeers = worldSize - 1;
+
+  // Phase 1: Local copy — all threads across all blocks cooperate
+  if (sendCounts[rank] > 0) {
+    mscclpp::copy((char*)recvBuff + recvDispls[rank],
+                  (void*)((const char*)sendBuff + sendDispls[rank]),
+                  sendCounts[rank], gtid, nThreads);
+  }
+
+  // Phase 2: Remote puts — all blocks cooperate on each peer's transfer
+  for (int peerIdx = 0; peerIdx < nPeers; peerIdx++) {
+    int peer = peerIdx < rank ? peerIdx : peerIdx + 1;
+    int chanIdx = peerIdx;
+
+    if (sendCounts[peer] > 0) {
+      memoryChannels[chanIdx].put(
+          remoteRecvDispls[peer],
+          sendDispls[peer],
+          sendCounts[peer],
+          gtid,
+          nThreads
+      );
+    }
+  }
+
+  // Phase 3: Grid-wide barrier
+  syncer->sync(gridDim.x);
+
+  // Phase 4: Signal all peers, then wait (single thread)
+  if (gtid == 0) {
+    for (int peerIdx = 0; peerIdx < nPeers; peerIdx++) {
+      memoryChannels[peerIdx].signal();
+    }
+    for (int peerIdx = 0; peerIdx < nPeers; peerIdx++) {
+      int peer = peerIdx < rank ? peerIdx : peerIdx + 1;
+      if (recvCounts[peer] > 0) {
+        memoryChannels[peerIdx].wait();
+      }
+    }
+  }
+}
+
 /**
  * High-performance AllToAllV kernel using maximum thread parallelism.
  *
diff --git a/test/mscclpp-test/alltoallv_test.cu b/test/mscclpp-test/alltoallv_test.cu
index 2d3740a3..a813e703 100644
--- a/test/mscclpp-test/alltoallv_test.cu
+++ b/test/mscclpp-test/alltoallv_test.cu
@@ -32,6 +32,9 @@ static size_t* d_remoteRecvDispls;  // peer's recvDispls[rank] for each peer
 // Device array for memory channels (used by library kernels)
 static DeviceHandle<mscclpp::MemoryChannel>* d_memoryChannels;
 
+// GPU-allocated DeviceSyncer for multi-block kernel
+static mscclpp::DeviceSyncer* d_deviceSyncer;
+
 class AllToAllVTestColl : public BaseTestColl {
  public:
  AllToAllVTestColl() = default;
@@ -88,6 +91,42 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) {
         d_sendCounts, d_sendDispls,
         d_recvCounts, d_recvDispls,
         d_remoteRecvDispls);
+  } else if (kernelNum == 3) {
+    // Use legacy multi-block kernel (sequential peers)
+    const int nBlocks = mscclpp::collective::ALLTOALLV_DEFAULT_NBLOCKS;
+    mscclpp::collective::alltoallvMultiBlockKernel<<<nBlocks, 1024, 0, stream>>>(
+        d_memoryChannels,
+        d_deviceSyncer,
+        rank, worldSize,
+        localSendBuffV, localRecvBuffV,
+        d_sendCounts, d_sendDispls,
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
+  } else if (kernelNum == 4) {
+    // Peer-parallel kernel: small messages (1 block/peer, no barrier)
+    const int nPeers = worldSize - 1;
+    const int nBlocks = (nPeers > 0) ? nPeers : 1;
+    mscclpp::collective::alltoallvPeerParallelKernel<<<nBlocks, 1024, 0, stream>>>(
+        d_memoryChannels,
+        d_deviceSyncer,
+        rank, worldSize,
+        localSendBuffV, localRecvBuffV,
+        d_sendCounts, d_sendDispls,
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
+  } else if (kernelNum == 5) {
+    // Peer-parallel kernel: large messages (multiple blocks/peer, barrier)
+    const int nPeers = worldSize - 1;
+    const int blocksPerPeer = mscclpp::collective::ALLTOALLV_DEFAULT_BLOCKS_PER_PEER;
+    const int nBlocks = (nPeers > 0) ? nPeers * blocksPerPeer : blocksPerPeer;
+    mscclpp::collective::alltoallvPeerParallelKernel<<<nBlocks, 1024, 0, stream>>>(
+        d_memoryChannels,
+        d_deviceSyncer,
+        rank, worldSize,
+        localSendBuffV, localRecvBuffV,
+        d_sendCounts, d_sendDispls,
+        d_recvCounts, d_recvDispls,
+        d_remoteRecvDispls);
   }
 }
 
@@ -183,7 +222,10 @@ std::vector<KernelRestriction> AllToAllVTestColl::getKernelRestrictions() {
   return {
       {0, "alltoallvKernel", true, 1, 4 * worldSize_},
       {1, "alltoallvRingKernel", true, 1, 4 * worldSize_},
-      {2, "alltoallvPipelinedKernel", true, 1, 4 * worldSize_}
+      {2, "alltoallvPipelinedKernel", true, 1, 4 * worldSize_},
+      {3, "alltoallvMultiBlockKernel", true, 1, 4 * worldSize_},
+      {4, "alltoallvPeerParallel(small)", true, 1, 4 * worldSize_},
+      {5, "alltoallvPeerParallel(large)", true, 1, 4 * worldSize_}
   };
 }
 
@@ -229,6 +271,10 @@ void AllToAllVTestEngine::allocateBuffer() {
 
   // Allocate device array for memory channels
   CUDATHROW(cudaMalloc(&d_memoryChannels, args_.totalRanks * sizeof(DeviceHandle<mscclpp::MemoryChannel>)));
+
+  // Allocate GPU DeviceSyncer for multi-block kernel (zero-initialized)
+  CUDATHROW(cudaMalloc(&d_deviceSyncer, sizeof(mscclpp::DeviceSyncer)));
+  CUDATHROW(cudaMemset(d_deviceSyncer, 0, sizeof(mscclpp::DeviceSyncer)));
 }
 
 void AllToAllVTestEngine::setupConnections() {
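
Reviewer note (not part of the patch): the following is a standalone, host-side C++ sketch of the arithmetic the patch introduces, namely the round-robin block-to-peer mapping in alltoallvPeerParallelKernel (myPeerIdx, localBlockIdx, nBlocksForMyPeer, localTid) and the grid-sizing cap in alltoallvKernelFunc. It only mirrors the formulas in the diff; the file name and the example sizes (worldSize = 8, 1024 threads per block, 16 blocks per peer) are illustrative assumptions, not library API.

// block_mapping_check.cpp: illustrative sketch, not part of the patch.
#include <cassert>
#include <cstdio>

int main() {
  const int worldSize = 8;            // assumed example: 8 GPUs
  const int nPeers = worldSize - 1;   // 7 remote peers
  const int threadsPerBlock = 1024;   // matches __launch_bounds__(1024)
  const int blocksPerPeer = 16;       // ALLTOALLV_DEFAULT_BLOCKS_PER_PEER

  // Grid sizing as in the large-message branch of alltoallvKernelFunc.
  int numBlocks = nPeers * blocksPerPeer;                    // 112 for this example
  if (numBlocks > 128) numBlocks = (128 / nPeers) * nPeers;  // keep a multiple of nPeers
  if (numBlocks < nPeers) numBlocks = nPeers;

  int blocksForPeer0 = 0;
  for (int blockIdx = 0; blockIdx < numBlocks; ++blockIdx) {
    const int myPeerIdx = blockIdx % nPeers;      // which peer this block serves
    const int localBlockIdx = blockIdx / nPeers;  // position inside that peer's block group
    const int nBlocksForMyPeer = (numBlocks - myPeerIdx + nPeers - 1) / nPeers;
    const int nLocalThreads = nBlocksForMyPeer * threadsPerBlock;

    // Every thread of this block gets a unique local id in [0, nLocalThreads)
    // for its peer, so cooperating blocks tile the peer's put without overlap.
    const int lastLocalTid = (threadsPerBlock - 1) + localBlockIdx * threadsPerBlock;
    assert(lastLocalTid < nLocalThreads);

    if (myPeerIdx == 0) ++blocksForPeer0;
  }

  std::printf("grid = %d blocks; peer 0 is served by %d blocks (%d threads)\n",
              numBlocks, blocksForPeer0, blocksForPeer0 * threadsPerBlock);
  return 0;
}

Compiled with any C++ compiler, the sketch prints "grid = 112 blocks; peer 0 is served by 16 blocks (16384 threads)" for the example sizes, which matches the intent that every peer keeps its full group of cooperating blocks when the default blocks-per-peer value is used.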