From 43980da455e55c1ca77301db0a12739df3631e4c Mon Sep 17 00:00:00 2001 From: Qinghua Zhou Date: Wed, 18 Feb 2026 03:00:29 +0000 Subject: [PATCH] Use maximum threads (1024) for best bandwidth utilization --- .../alltoallv/alltoallv_fullmesh.cu | 15 ++-- .../include/alltoallv/alltoallv_kernel.hpp | 78 +++++++++---------- test/mscclpp-test/alltoallv_test.cu | 17 ++-- 3 files changed, 48 insertions(+), 62 deletions(-) diff --git a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu index 032ac1aa..69e5d1bb 100644 --- a/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu +++ b/src/ext/collectives/alltoallv/alltoallv_fullmesh.cu @@ -71,7 +71,7 @@ void AlltoallvFullmesh::initialize(std::shared_ptr comm) { CommResult AlltoallvFullmesh::alltoallvKernelFunc( const std::shared_ptr ctx, const void* input, void* output, size_t inputSize, size_t outputSize, [[maybe_unused]] DataType dtype, cudaStream_t stream, - [[maybe_unused]] int nBlocks, [[maybe_unused]] int nThreadsPerBlock, + [[maybe_unused]] int nBlocks, int nThreadsPerBlock, const std::unordered_map& extras) { auto algoCtx = std::static_pointer_cast(ctx); @@ -94,14 +94,13 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc( const size_t* d_recvCounts = reinterpret_cast(it_recvCounts->second); const size_t* d_recvDispls = reinterpret_cast(it_recvDispls->second); + // Use maximum threads (1024) for best bandwidth utilization + const int threadsPerBlock = (nThreadsPerBlock > 0 && nThreadsPerBlock <= 1024) ? 
nThreadsPerBlock : 1024; + // Choose kernel based on world size if (worldSize <= 16) { - // Use parallel warp-based kernel for small world sizes - int nThreads = (worldSize - 1) * ALLTOALLV_WARP_SIZE; - if (nThreads < 32) nThreads = 32; - if (nThreads > 1024) nThreads = 1024; - - alltoallvKernel<<<1, nThreads, 0, stream>>>( + // Use high-throughput kernel with all threads + alltoallvKernel<<<1, threadsPerBlock, 0, stream>>>( algoCtx->memoryChannelDeviceHandles.get(), rank, worldSize, input, output, @@ -109,7 +108,7 @@ CommResult AlltoallvFullmesh::alltoallvKernelFunc( d_recvCounts, d_recvDispls); } else { // Use ring-based kernel for larger world sizes - alltoallvRingKernel<<<1, 32, 0, stream>>>( + alltoallvRingKernel<<<1, threadsPerBlock, 0, stream>>>( algoCtx->memoryChannelDeviceHandles.get(), rank, worldSize, input, output, diff --git a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp index 7fc684ab..4db36555 100644 --- a/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp +++ b/src/ext/collectives/include/alltoallv/alltoallv_kernel.hpp @@ -17,10 +17,13 @@ namespace collective { #endif /** - * AllToAllV kernel implementation using parallel warp-based communication with MemoryChannel. + * High-performance AllToAllV kernel using maximum thread parallelism. * - * Each warp handles communication with one peer. Data is copied in parallel using all threads - * in the warp, which significantly improves throughput for large messages. + * Processes each peer sequentially but uses ALL block threads (1024) for each + * data transfer to maximize copy bandwidth. This provides much better performance + * than the warp-per-peer approach for large message sizes. + * + * Launch config: <<<1, 1024>>> for maximum bandwidth within a single block. 
* * @param memoryChannels Array of MemoryChannel handles for each peer (worldSize-1 channels) * @param rank Current rank @@ -43,64 +46,54 @@ __global__ void __launch_bounds__(1024) const size_t* recvCounts, const size_t* recvDispls) { int tid = threadIdx.x; + int nThreads = blockDim.x; int nPeers = worldSize - 1; - // Step 1: Copy local data (rank's own portion) using all threads + // Step 1: Copy local data using ALL threads for maximum bandwidth if (sendCounts[rank] > 0) { mscclpp::copy((char*)recvBuff + recvDispls[rank], (void*)((const char*)sendBuff + sendDispls[rank]), - sendCounts[rank], tid, blockDim.x); + sendCounts[rank], tid, nThreads); } __syncthreads(); - // Step 2: Each warp handles one peer for sending (parallel copy within warp) - int warpId = tid / ALLTOALLV_WARP_SIZE; - int laneId = tid % ALLTOALLV_WARP_SIZE; - - if (warpId < nPeers) { - // Determine which peer this warp handles - int peer = warpId < rank ? warpId : warpId + 1; - int chanIdx = warpId; + // Step 2: Process each peer sequentially, but use ALL threads for each transfer + // This maximizes bandwidth for each transfer compared to warp-per-peer approach + for (int peerIdx = 0; peerIdx < nPeers; peerIdx++) { + int peer = peerIdx < rank ? 
peerIdx : peerIdx + 1; + int chanIdx = peerIdx; if (sendCounts[peer] > 0) { - // Use parallel put with all threads in the warp - // targetOffset: recvDispls[rank] - where peer should receive our data - // originOffset: sendDispls[peer] - where our data for this peer starts + // Use all threads for maximum copy throughput memoryChannels[chanIdx].put( recvDispls[rank], // dst offset in peer's buffer sendDispls[peer], // src offset in our buffer sendCounts[peer], // size - laneId, // thread id within warp - ALLTOALLV_WARP_SIZE // number of threads + tid, // thread id + nThreads // total threads ); } - } - __syncthreads(); + __syncthreads(); - // Step 3: Signal completion to all peers - if (warpId < nPeers && laneId == 0) { - memoryChannels[warpId].signal(); - } - __syncthreads(); - - // Step 4: Wait for all incoming data - if (warpId < nPeers && laneId == 0) { - int peer = warpId < rank ? warpId : warpId + 1; - if (recvCounts[peer] > 0) { - memoryChannels[warpId].wait(); + // Only one thread signals per peer + if (tid == 0) { + memoryChannels[chanIdx].signal(); } + __syncthreads(); + + // Wait for incoming data from this peer + if (tid == 0 && recvCounts[peer] > 0) { + memoryChannels[chanIdx].wait(); + } + __syncthreads(); } - __syncthreads(); } /** - * Ring-based AllToAllV kernel for serialized communication with MemoryChannel. + * Ring-based AllToAllV kernel with maximum thread parallelism. * - * Uses step-by-step ring pattern to exchange data, sending to (rank+step) and - * receiving from (rank-step) in each step. All threads participate in the copy - * for better throughput. - * - * This kernel is more robust for larger world sizes. + * Uses step-by-step ring pattern with ALL threads for maximum bandwidth. + * Better for larger world sizes to avoid congestion. 
*/ __global__ void __launch_bounds__(1024) alltoallvRingKernel(DeviceHandle* memoryChannels, @@ -113,12 +106,13 @@ __global__ void __launch_bounds__(1024) const size_t* recvCounts, const size_t* recvDispls) { int tid = threadIdx.x; + int nThreads = blockDim.x; - // Copy local data first using all threads + // Copy local data first using ALL threads if (sendCounts[rank] > 0) { mscclpp::copy((char*)recvBuff + recvDispls[rank], (void*)((const char*)sendBuff + sendDispls[rank]), - sendCounts[rank], tid, blockDim.x); + sendCounts[rank], tid, nThreads); } __syncthreads(); @@ -130,14 +124,14 @@ __global__ void __launch_bounds__(1024) int sendChanIdx = sendPeer < rank ? sendPeer : sendPeer - 1; int recvChanIdx = recvPeer < rank ? recvPeer : recvPeer - 1; - // Send data to sendPeer using all threads + // Send data to sendPeer using ALL threads if (sendCounts[sendPeer] > 0) { memoryChannels[sendChanIdx].put( recvDispls[rank], sendDispls[sendPeer], sendCounts[sendPeer], tid, - blockDim.x + nThreads ); } __syncthreads(); diff --git a/test/mscclpp-test/alltoallv_test.cu b/test/mscclpp-test/alltoallv_test.cu index bcffd224..43af8574 100644 --- a/test/mscclpp-test/alltoallv_test.cu +++ b/test/mscclpp-test/alltoallv_test.cu @@ -57,18 +57,11 @@ void AllToAllVTestColl::runColl(const TestArgs& args, cudaStream_t stream) { const int rank = args.rank; const int kernelNum = args.kernelNum; - // Reset device syncer - mscclpp::DeviceSyncer syncer = {}; - CUDATHROW(cudaMemcpyToSymbol(deviceSyncerV, &syncer, sizeof(mscclpp::DeviceSyncer))); + // Use maximum threads (1024) for best bandwidth utilization + const int nThreads = 1024; if (kernelNum == 0) { - // Use parallel warp-based kernel from library - int nThreads = (worldSize - 1) * 32; // One warp per peer -#if defined(__HIP_PLATFORM_AMD__) - nThreads = (worldSize - 1) * 64; -#endif - if (nThreads < 32) nThreads = 32; - if (nThreads > 1024) nThreads = 1024; + // Use high-throughput kernel with all threads participating in each transfer 
mscclpp::collective::alltoallvKernel<<<1, nThreads, 0, stream>>>( d_memoryChannels, rank, worldSize, @@ -76,8 +69,8 @@ d_sendCounts, d_sendDispls, d_recvCounts, d_recvDispls); } else if (kernelNum == 1) { - // Use ring-based kernel from library - mscclpp::collective::alltoallvRingKernel<<<1, 32, 0, stream>>>( + // Use ring-based kernel from library with all threads (selected via kernelNum) + mscclpp::collective::alltoallvRingKernel<<<1, nThreads, 0, stream>>>( d_memoryChannels, rank, worldSize, localSendBuffV, localRecvBuffV,