From 307a4718884a59dd2acead9ca899a1667598b470 Mon Sep 17 00:00:00 2001 From: Binyang Li Date: Wed, 6 May 2026 21:37:09 +0000 Subject: [PATCH] Shorten verbose comments and use THROW in validateIpcDomainSpansWorld - Collapse the duplicated 3-line warp-strided-load comment in 5 kernels (allgather_fullmesh, allreduce_fullmesh, allreduce_packet, allreduce_nvls_zero_copy, allreduce_nvls_warp_pipeline) into a single one-line 'Peer count may exceed WARP_SIZE on MNNVL.' note. - Drop the algName parameter from validateIpcDomainSpansWorld; switch its 3 throws to use the THROW logger macro (LogSubsys::ALGO), which already captures file/line/function. Update the 3 callsites (nvls_block_pipeline, nvls_warp_pipeline, nvls_zero_copy) and trim the Doxygen comment accordingly. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../allgather/allgather_fullmesh.cu | 4 +--- .../allreduce/allreduce_fullmesh.cu | 4 +--- .../allreduce_nvls_block_pipeline.cu | 2 +- .../allreduce/allreduce_nvls_warp_pipeline.cu | 6 ++---- .../allreduce/allreduce_nvls_zero_copy.cu | 6 ++---- .../collectives/allreduce/allreduce_packet.cu | 4 +--- src/ext/collectives/collective_utils.cc | 20 +++++++++---------- .../collectives/include/collective_utils.hpp | 10 +++------- 8 files changed, 21 insertions(+), 35 deletions(-) diff --git a/src/ext/collectives/allgather/allgather_fullmesh.cu b/src/ext/collectives/allgather/allgather_fullmesh.cu index 8ce77fca..a4196c6c 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh.cu @@ -30,9 +30,7 @@ __global__ void __launch_bounds__(1024, 1) __shared__ DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = threadIdx.x % WARP_SIZE; - // Each warp redundantly loads all entries (same value, benign race) so that - // every warp has the data its threads will read after __syncwarp(). Required - // when nPeer > WARP_SIZE (MNNVL/NVL72 scale). + // Peer count may exceed WARP_SIZE on MNNVL. for (int i = lid; i < nPeer; i += WARP_SIZE) { channels[i] = memoryChans[i]; } diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu index f1d81560..ef7ecf74 100644 --- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu +++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu @@ -52,9 +52,7 @@ __global__ void __launch_bounds__(512, 1) __shared__ DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; __shared__ DeviceHandle outChannels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = threadIdx.x % WARP_SIZE; - // Each warp redundantly loads all entries (same value, benign race) so that - // every warp has the data its threads will read after __syncwarp(). Required - // when nPeer > WARP_SIZE (MNNVL/NVL72 scale). + // Peer count may exceed WARP_SIZE on MNNVL. for (int i = lid; i < nPeer; i += WARP_SIZE) { channels[i] = memoryChans[i]; outChannels[i] = memoryOutChans[i]; diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index 4eeb0335..8c4a1e23 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -178,7 +178,7 @@ struct NvlsBlockPipelineAdapter { void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = 8; - ipcDomainNranks_ = validateIpcDomainSpansWorld(comm, "AllreduceNvlsBlockPipeline"); + ipcDomainNranks_ = validateIpcDomainSpansWorld(comm); // Block-pipeline device-side semaphore indices grow as 6 * ipcDomainNranks (see kernel). if (6 * ipcDomainNranks_ > NUM_SEMAPHORES) { throw Error("AllreduceNvlsBlockPipeline: ipcDomainNranks " + std::to_string(ipcDomainNranks_) + diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index 05e4f747..950c287b 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -59,9 +59,7 @@ __global__ void __launch_bounds__(1024, 1) auto memoryChans = memoryChannels + chanOffset; __shared__ DeviceHandle channels[(MAX_IPC_DOMAIN_NRANKS - 1) * 2]; const int lid = threadIdx.x % WARP_SIZE; - // Each warp redundantly loads all entries (same value, benign race) so that - // every warp has the data its threads will read after __syncwarp(). Required - // when nPeers*2 > WARP_SIZE (MNNVL scale). + // Peer count may exceed WARP_SIZE on MNNVL. for (int i = lid; i < nPeers * 2; i += WARP_SIZE) { channels[i] = memoryChans[i]; } @@ -144,7 +142,7 @@ struct NvlsWarpPipelineAdapter { void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = NUM_NVLS_CONNECTION; - ipcDomainNranks_ = validateIpcDomainSpansWorld(comm, "AllreduceNvlsWarpPipeline"); + ipcDomainNranks_ = validateIpcDomainSpansWorld(comm); // The warp-pipeline kernel addresses 2 * nPeers entries per block in `memoryChannels`, // so per-peer base channel allocation must be at least `2 * nBlocks`. Default // nBlocks = 4 * ipcDomainNranks (see allreduceKernelFunc), so size accordingly. diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index 5d6fc4d3..6ab0cd63 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -45,9 +45,7 @@ __global__ void __launch_bounds__(1024, 1) auto memoryChans = memoryChannels + chanOffset; __shared__ mscclpp::DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = threadIdx.x % WARP_SIZE; - // Each warp redundantly loads all entries (same value, benign race) so that - // every warp has the data its threads will read after __syncwarp(). Required - // when nPeers > WARP_SIZE (MNNVL/NVL72 → 71 peers). + // Peer count may exceed WARP_SIZE on MNNVL. for (int i = lid; i < ipcDomainNranks - 1; i += WARP_SIZE) { channels[i] = memoryChans[i]; } @@ -107,7 +105,7 @@ void AllreduceNvls::initialize(std::shared_ptr comm) { MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&deviceProp, device)); computeCapabilityMajor_ = deviceProp.major; nSwitchChannels_ = 32; - validateIpcDomainSpansWorld(comm, "AllreduceNvls"); + validateIpcDomainSpansWorld(comm); this->conns_ = setupConnections(comm); // setup semaphores std::vector> memorySemaphores = diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index cc91370c..7bc9a85f 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -80,9 +80,7 @@ __global__ void __launch_bounds__(1024, 1) // Put channels into shared memory, read channel info from global memory is unexpectable slow. __shared__ mscclpp::DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = tid % WARP_SIZE; - // Each warp redundantly loads all entries (same value, benign race) so that - // every warp has the data its threads will read after __syncwarp(). Required - // when nPeers > WARP_SIZE (MNNVL/NVL72 scale). + // Peer count may exceed WARP_SIZE on MNNVL. for (int i = lid; i < nPeers; i += WARP_SIZE) { channels[i] = memoryChannels[i]; } diff --git a/src/ext/collectives/collective_utils.cc b/src/ext/collectives/collective_utils.cc index 33b6ef77..e4eb7142 100644 --- a/src/ext/collectives/collective_utils.cc +++ b/src/ext/collectives/collective_utils.cc @@ -11,6 +11,8 @@ #include #include +#include "logger.hpp" + namespace mscclpp { namespace collective { std::vector setupRemoteMemories(std::shared_ptr comm, int rank, @@ -79,24 +81,22 @@ int getIpcDomainNranks(std::shared_ptr comm) { return comm->bootstrap()->getNranksPerNode(); } -int validateIpcDomainSpansWorld(std::shared_ptr comm, const char* algName) { +int validateIpcDomainSpansWorld(std::shared_ptr comm) { const int ipcDomainNranks = getIpcDomainNranks(comm); const int worldSize = comm->bootstrap()->getNranks(); const int rank = comm->bootstrap()->getRank(); if (ipcDomainNranks < 2 || ipcDomainNranks > MAX_IPC_DOMAIN_NRANKS) { - throw mscclpp::Error(std::string(algName) + ": ipcDomainNranks " + std::to_string(ipcDomainNranks) + - " is out of supported range [2, " + std::to_string(MAX_IPC_DOMAIN_NRANKS) + "]", - mscclpp::ErrorCode::InvalidUsage); + THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "ipcDomainNranks ", + ipcDomainNranks, " is out of supported range [2, ", MAX_IPC_DOMAIN_NRANKS, "]"); } if (worldSize != ipcDomainNranks) { - throw mscclpp::Error(std::string(algName) + " requires worldSize == ipcDomainNranks (got worldSize=" + - std::to_string(worldSize) + ", ipcDomainNranks=" + std::to_string(ipcDomainNranks) + ")", - mscclpp::ErrorCode::InvalidUsage); + THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, + "requires worldSize == ipcDomainNranks (got worldSize=", worldSize, ", ipcDomainNranks=", ipcDomainNranks, + ")"); } if (rank < 0 || rank >= ipcDomainNranks) { - throw mscclpp::Error(std::string(algName) + ": rank " + std::to_string(rank) + " out of [0, " + - std::to_string(ipcDomainNranks) + ")", - mscclpp::ErrorCode::InvalidUsage); + THROW(mscclpp::LogSubsys::ALGO, mscclpp::Error, mscclpp::ErrorCode::InvalidUsage, "rank ", rank, " out of [0, ", + ipcDomainNranks, ")"); } return ipcDomainNranks; } diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index 892df3b1..6b0c6ab4 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -60,13 +60,9 @@ int getIpcDomainNranks(std::shared_ptr comm); /// Validates that the IPC domain spans the whole communicator and that the local rank fits within /// the supported `[2, MAX_IPC_DOMAIN_NRANKS]` range. Used by NVLS allreduce algorithms whose -/// multicast group spans the whole communicator (see `setupNvlsConnections`) and whose kernels -/// use the global rank to compute per-rank offsets while sizing per-rank work by -/// `ipcDomainNranks`. These assumptions only hold when the IPC-reachable peer group is exactly -/// the whole communicator (e.g. a fully populated MNNVL fabric like NVL72). Returns the validated -/// `ipcDomainNranks`; throws `Error(InvalidUsage)` on violation. `algName` is used as a prefix -/// in error messages. -int validateIpcDomainSpansWorld(std::shared_ptr comm, const char* algName); +/// multicast group spans the whole communicator. Returns the validated `ipcDomainNranks`; throws +/// `Error(InvalidUsage)` on violation. +int validateIpcDomainSpansWorld(std::shared_ptr comm); std::shared_ptr> setupMemoryChannelDeviceHandles( const std::vector& memoryChannels);