From 9e177b388c1ace99be176c9553512849e2df6ae7 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 20 May 2026 16:49:49 -0700 Subject: [PATCH 1/2] remove useless sync (#809) --- src/ext/collectives/allreduce/allreduce_allpair_packet.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index faef5459..49058f59 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -63,7 +63,6 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand } dst[idx] = mscclpp::downcastVector(acc); } - __syncthreads(); if (threadIdx.x == 0) { ((uint32_t*)flags)[blockIdx.x] = flag + 1; } From 08ee18be64248ebdbd94bc758e2cbb59e3ac9bf5 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Fri, 22 May 2026 09:18:41 -0700 Subject: [PATCH 2/2] Add check to filter invalid nblock/nthread candidates (#811) Add check for invalid nblock/nthread candidate --- .../customized_comm_with_tuning.py | 4 +-- .../allgather/allgather_fullmesh.cu | 30 ++++++++++++------- .../allgather/allgather_fullmesh_2.cu | 26 ++++++++++++++-- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index 060a0097..b96087c2 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -70,12 +70,12 @@ class CustomizedComm: _TUNE_N_WARMUP = 5 _TUNE_N_GRAPH_LAUNCHES = 10 _TUNE_N_OPS_PER_GRAPH = 100 - _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 64, 128] + _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 56, 64, 128] _CANDIDATE_NTHREADS = [512, 768, 1024] _NBLOCKS_LIMIT = { "default_allreduce_nvls_packet": 16, "default_allreduce_packet": 56, - "default_allreduce_allpair_packet": 56, + "default_allreduce_allpair_packet": 64, "default_allreduce_fullmesh": 64, "default_allgather_fullmesh2": 32, } diff --git a/src/ext/collectives/allgather/allgather_fullmesh.cu b/src/ext/collectives/allgather/allgather_fullmesh.cu index fb51a342..d1b4e731 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh.cu @@ -8,6 +8,11 @@ namespace mscclpp { namespace collective { +namespace { +constexpr int kMaxBlocks = 56; +constexpr int kMaxThreadsPerBlock = 1024; +} // namespace + template __global__ void __launch_bounds__(1024, 1) allgatherFullmesh(void* buff, void* scratch, void* resultBuff, DeviceHandle* memoryChannels, @@ -116,12 +121,19 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr ct int rank = ctx->rank; const size_t nElem = inputSize / sizeof(int); std::pair numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; - if (numBlocksAndThreads.first > 56) { - WARN("AllgatherFullmesh: number of blocks exceeds maximum supported blocks, which is 56"); - return mscclpp::CommResult::CommInvalidArgument; - } if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) { - numBlocksAndThreads = {56, 1024}; + numBlocksAndThreads = {kMaxBlocks, kMaxThreadsPerBlock}; + } + if (numBlocksAndThreads.first > kMaxBlocks || numBlocksAndThreads.second > kMaxThreadsPerBlock) { + WARN( + "AllgatherFullmesh: number of blocks must be no more than %d and threads per block must be no more than %d; " + "got nBlocks=%d, nThreadsPerBlock=%d", + kMaxBlocks, kMaxThreadsPerBlock, numBlocksAndThreads.first, numBlocksAndThreads.second); + return CommResult::CommInvalidArgument; + } + if (numBlocksAndThreads.second % WARP_SIZE != 0) { + WARN("AllgatherFullmesh: threads per block must be a multiple of warp size %d", WARP_SIZE); + return CommResult::CommInvalidArgument; } if ((char*)input == (char*)output + rank * inputSize) { allgatherFullmesh<<>>( @@ -142,15 +154,13 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr ct std::shared_ptr AllgatherFullmesh::initAllgatherContext(std::shared_ptr comm, const void* input, void*, size_t inputSize, DataType) { - constexpr int nChannelsPerConnection = 56; - auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // setup semaphores - ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection); + ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, kMaxBlocks); // register the memory for the broadcast operation RegisteredMemory localMemory = comm->registerMemory((void*)input, inputSize, Transport::CudaIpc); @@ -159,7 +169,7 @@ std::shared_ptr AllgatherFullmesh::initAllgatherContext(std::shared_ptrmemoryChannels = - setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection); + setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, kMaxBlocks); ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels); // keep registered memories reference @@ -196,4 +206,4 @@ std::shared_ptr AllgatherFullmesh::build() { }); } } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index 9d169d68..89581822 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -18,7 +18,11 @@ __global__ void __launch_bounds__(1024, 1) const size_t lid = tid % WARP_SIZE; const size_t wid = tid / WARP_SIZE; - const size_t nThread = blockDim.x * gridDim.x; + // Round down to multiple of warp size + const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE * WARP_SIZE; + if (tid >= nThread) { + return; + } const size_t nWarp = nThread / WARP_SIZE; const size_t nPeer = nRanksPerNode - 1; const size_t chanOffset = nPeer * blockIdx.x; @@ -135,6 +139,24 @@ CommResult AllgatherFullmesh2::allgatherKernelFunc(const std::shared_ptr c numBlocksAndThreads.first = 35; } } + const int nPeer = ctx->nRanksPerNode - 1; + const int nWarp = numBlocksAndThreads.first * numBlocksAndThreads.second / WARP_SIZE; + if (numBlocksAndThreads.first > nChannelsPerConnection_ || numBlocksAndThreads.first <= 0 || + numBlocksAndThreads.second <= 0) { + WARN( + "AllgatherFullmesh2: number of blocks must be a positive multiple of peer count and no more than %d, threads " + "per block must be positive; got nBlocks=%d, nThreadsPerBlock=%d, nPeers=%d", + nChannelsPerConnection_, numBlocksAndThreads.first, numBlocksAndThreads.second, nPeer); + return CommResult::CommInvalidArgument; + } + if (nWarp < nPeer) { + WARN( + "AllgatherFullmesh2: total number of warps must be no less than peer count; got nBlocks=%d, " + "nThreadsPerBlock=%d, " + "nPeers=%d", + numBlocksAndThreads.first, numBlocksAndThreads.second, nPeer); + return CommResult::CommInvalidArgument; + } size_t channelOutOffset = *static_cast(ctx->extras["channel_out_offset"].get()); if ((char*)input == (char*)output + rank * inputSize) { @@ -226,4 +248,4 @@ std::shared_ptr AllgatherFullmesh2::build() { } } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp